From b04cc917735d7f80225a314fac746ed761c327fc Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Wed, 13 Jan 2021 05:26:00 +0000 Subject: [PATCH 1/9] Implement cuda kernel for index_sample. --- paddle/fluid/operators/index_sample_op.cu | 213 +++++++++++++++++++++- 1 file changed, 205 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 1dc7a128edc47..28d2d830d3512 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -12,18 +12,215 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_sample_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void index_kernel(const IndexT* p_index, const T* p_input, + T* p_output, size_t stride_index, + size_t stride_input, size_t height) { + int ix = blockDim.x * blockIdx.x + threadIdx.x; + int iy = blockDim.y * blockIdx.y + threadIdx.y; + int tid = iy * stride_index + ix; + int tid_x = iy * stride_input + ix; + int tid_y = iy * stride_index + ix; + + if (ix < stride_index & iy < height) { + IndexT idx = p_index[tid]; + p_output[tid_y] = p_input[tid_x - ix + idx]; + } +} + +template +__global__ void index_kernel_grad(const IndexT* p_index, const T* p_input, + T* p_output, size_t stride_index, + size_t stride_input, size_t height) { + int ix = blockDim.x * blockIdx.x + threadIdx.x; + int iy = blockDim.y * blockIdx.y + threadIdx.y; + int tid = iy * stride_index + ix; + int tid_y = iy * stride_input + ix; + + if (ix < stride_index & iy < height) { + IndexT idx = p_index[tid]; + p_output[tid_y - ix + idx] += p_input[tid]; + } +} + +template +class IndexSampleCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* index = ctx.Input("Index"); + auto* output = ctx.Output("Out"); + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); + + auto ComputeBlockSize = [](int col) { + if (col > 512) + return 1024; + else if (col > 256) + return 512; + else if (col > 128) + return 256; + else if (col > 64) + return 128; + else if (col > 32) + return 64; + else if (col > 16) + return 32; + else if (col > 8) + return 16; + else + return 8; + }; + const auto* in_data = input->data(); + auto* out_data = output->mutable_data(ctx.GetPlace()); + auto stream = + ctx.template device_context().stream(); + + auto input_dim = input->dims(); + auto index_dim = index->dims(); + size_t batch_size = input_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + + auto block_width = 
ComputeBlockSize(index_length); + int block_height = + ComputeBlockSize(index_length * batch_size) / block_width; + + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_kernel<<>>( + index_data, in_data, out_data, index_length, input_length, + batch_size); + } else if (index_type == framework::proto::VarType::INT32) { + const int* index_data = index->data(); + index_kernel<<>>( + index_data, in_data, out_data, index_length, input_length, + batch_size); + } + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } +}; + +template +class IndexSampleGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* index = ctx.Input("Index"); + + const auto* output_grad_data = output_grad->data(); + auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); + + auto ComputeBlockSize = [](int col) { + if (col > 512) + return 1024; + else if (col > 256) + return 512; + else if (col > 128) + return 256; + else if (col > 64) + return 128; + else if (col > 32) + return 64; + else if (col > 16) + return 32; + else if (col > 8) + return 16; + else + return 8; + }; + auto stream = + ctx.template device_context().stream(); + + auto input_dim = input_grad->dims(); + auto index_dim = index->dims(); + size_t batch_size = index_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + + auto block_width = ComputeBlockSize(index_length); + auto block_height = + ComputeBlockSize(index_length * batch_size) / block_width; + + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_kernel_grad<<>>( + index_data, output_grad_data, input_grad_data, index_length, + input_length, batch_size); + } else if (index_type == framework::proto::VarType::INT32) { + const int* index_data = index->data(); + index_kernel_grad<<>>( + index_data, output_grad_data, input_grad_data, index_length, + input_length, batch_size); + } + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( index_sample, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel, - ops::IndexSampleKernel); + ops::IndexSampleCUDAKernel, + ops::IndexSampleCUDAKernel, + ops::IndexSampleCUDAKernel, + 
ops::IndexSampleCUDAKernel); REGISTER_OP_CUDA_KERNEL( index_sample_grad, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel, - ops::IndexSampleGradKernel); + ops::IndexSampleGradCUDAKernel, + ops::IndexSampleGradCUDAKernel, + ops::IndexSampleGradCUDAKernel, + ops::IndexSampleGradCUDAKernel);

From 866dd2853b23359173be6c75f43cd5ecee788bca Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Thu, 14 Jan 2021 13:11:02 +0000 Subject: [PATCH 2/9] [WIP]: index_sample grad basic kernel realization. Needs to be optimized.

--- paddle/fluid/operators/index_sample_op.cu | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 28d2d830d3512..d064f5cbb76a9 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -40,17 +40,25 @@ __global__ void index_kernel(const IndexT* p_index, const T* p_input, }

 template -__global__ void index_kernel_grad(const IndexT* p_index, const T* p_input, - T* p_output, size_t stride_index, +__global__ void index_kernel_grad(const IndexT_* p_index, T* p_input, + const T* p_output, size_t stride_index, size_t stride_input, size_t height) { + extern __shared__ T s_buf[]; int ix = blockDim.x * blockIdx.x + threadIdx.x; int iy = blockDim.y * blockIdx.y + threadIdx.y; int tid = iy * stride_index + ix; int tid_y = iy * stride_input + ix; + s_buf[tid_y] = p_input[tid_y]; + s_buf[tid_y] = 0;

 if (ix < stride_index & iy < height) { - IndexT idx = p_index[tid]; - p_output[tid_y - ix + idx] += p_input[tid]; + for (int i = 0; i < stride_index; ++i) { + if (ix == i) { + IndexT idx = p_index[tid]; + s_buf[tid_y - ix + idx] += p_output[tid]; + } + } + p_input[tid_y] = s_buf[tid_y]; } }

@@ -178,6 +186,7 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { auto stream = ctx.template device_context().stream();

+ auto input_num = input_grad->numel(); auto input_dim = input_grad->dims(); auto index_dim = index->dims(); size_t batch_size = index_dim[0];

@@ -194,12 +203,14 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel {

 if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - index_kernel_grad<<>>( + index_kernel_grad< + T, int64_t><<>>( index_data, output_grad_data, input_grad_data, index_length, input_length, batch_size); } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - index_kernel_grad<<>>( + index_kernel_grad< + T, int><<>>( index_data, output_grad_data, input_grad_data, index_length, input_length, batch_size); }

From d18b4bfb41188c8e27462ccb2fdddbf969a925df Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Fri, 15 Jan 2021 09:15:41 +0000 Subject: [PATCH 3/9] [WIP]: Very basic grad function which passes the CTest; desperately in need of optimization.
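For context, this is the scatter-add the backward kernel has to realize. A minimal CPU reference follows (an illustrative sketch only, not part of the patch; the helper name IndexSampleGradRef is hypothetical):

#include <cstddef>
#include <vector>

// CPU reference for the index_sample backward pass. Several columns of one
// index row may point at the same input column, so the "+=" below is a
// read-modify-write that a naive one-thread-per-element CUDA kernel would
// race on. The WIP kernel in this patch serializes the adds within a row
// through its `if (ix == i)` loop; a later patch switches to atomicAdd.
template <typename T, typename IndexT>
void IndexSampleGradRef(const std::vector<IndexT>& index,
                        const std::vector<T>& out_grad,
                        std::vector<T>* in_grad,  // zero-initialized, of size
                                                  // batch_size * input_length
                        size_t index_length, size_t input_length,
                        size_t batch_size) {
  for (size_t row = 0; row < batch_size; ++row) {
    for (size_t col = 0; col < index_length; ++col) {
      IndexT idx = index[row * index_length + col];
      (*in_grad)[row * input_length + idx] +=
          out_grad[row * index_length + col];
    }
  }
}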
--- paddle/fluid/operators/index_sample_op.cu | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index d064f5cbb76a9..519e6ca1ab943 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -40,25 +40,23 @@ __global__ void index_kernel(const IndexT* p_index, const T* p_input, }

 template -__global__ void index_kernel_grad(const IndexT_* p_index, T* p_input, +__global__ void index_kernel_grad(const IndexT* p_index, T* p_input, const T* p_output, size_t stride_index, size_t stride_input, size_t height) { - extern __shared__ T s_buf[]; int ix = blockDim.x * blockIdx.x + threadIdx.x; int iy = blockDim.y * blockIdx.y + threadIdx.y; int tid = iy * stride_index + ix; int tid_y = iy * stride_input + ix; - s_buf[tid_y] = p_input[tid_y]; - s_buf[tid_y] = 0;

 if (ix < stride_index & iy < height) { for (int i = 0; i < stride_index; ++i) { if (ix == i) { IndexT idx = p_index[tid]; - s_buf[tid_y - ix + idx] += p_output[tid]; + T tmp1 = p_output[tid]; + T tmp2 = p_input[tid_y - ix + idx]; + p_input[tid_y - ix + idx] = tmp1 + tmp2; } } - p_input[tid_y] = s_buf[tid_y]; } }

@@ -201,17 +199,18 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y);

+ platform::GpuMemsetAsync(input_grad_data, 0, + sizeof(T) * input_length * batch_size, stream); +

 if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - index_kernel_grad< - T, int64_t><<>>( - index_data, output_grad_data, input_grad_data, index_length, + index_kernel_grad<<>>( + index_data, input_grad_data, output_grad_data, index_length, input_length, batch_size); } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - index_kernel_grad< - T, int><<>>( - index_data, output_grad_data, input_grad_data, index_length, + index_kernel_grad<<>>( + index_data, input_grad_data, output_grad_data, index_length, input_length, batch_size); } PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
From 98a9af7cc7b0c4d880fb90fe1ecf42286a0594c4 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Fri, 22 Jan 2021 09:54:38 +0000 Subject: [PATCH 5/9] [MODIFIED]: Changing the grad kernel to an atomicAdd style, which guarantees thread-safety when calculating the backward step of the index_sample OP, and adding one special CUDA kernel for the case where each line of the index array contains only 1 element. Besides, the thread deployment in each block is two-dimensional.
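The core of the change, shown in isolation: a minimal sketch of the atomicAdd-based scatter-add this patch adopts (illustrative only; the kernel name is hypothetical, and the plain CUDA atomicAdd stands in for the platform::CudaAtomicAdd wrapper used in the diff below):

#include <cuda_runtime.h>

// Thread-safe scatter-add with a 2-D thread layout: x indexes the column
// within a row, y indexes the row. When two threads of the same row hold
// index entries pointing at the same input column, atomicAdd serializes
// their updates instead of letting them race. (Plain atomicAdd has device
// overloads for float and int; double needs compute capability 6.0+.)
template <typename T, typename IndexT>
__global__ void ScatterAddRowsSketch(const IndexT* index, const T* out_grad,
                                     T* in_grad, size_t index_length,
                                     size_t input_length, size_t batch_size) {
  int col = blockDim.x * blockIdx.x + threadIdx.x;
  int row = blockDim.y * blockIdx.y + threadIdx.y;
  if (col < index_length && row < batch_size) {
    IndexT idx = index[row * index_length + col];
    atomicAdd(&in_grad[row * input_length + idx],
              out_grad[row * index_length + col]);
  }
}

When each index row holds exactly one element, no two threads of a row can collide, which is why this patch also adds a special kernel that uses a plain store instead of an atomic.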
--- paddle/fluid/operators/index_sample_op.cu | 129 ++++++++---------- .../tests/unittests/test_index_sample_op.py | 4 +- 2 files changed, 59 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 519e6ca1ab943..c97360345a4dc 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_sample_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { @@ -24,39 +25,47 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -__global__ void index_kernel(const IndexT* p_index, const T* p_input, - T* p_output, size_t stride_index, - size_t stride_input, size_t height) { +__global__ void IndexSampleForward(const IndexT* index, const T* in_data, + T* out_data, size_t index_length, + size_t input_length, size_t batch_size) { int ix = blockDim.x * blockIdx.x + threadIdx.x; int iy = blockDim.y * blockIdx.y + threadIdx.y; - int tid = iy * stride_index + ix; - int tid_x = iy * stride_input + ix; - int tid_y = iy * stride_index + ix; + int tid = iy * index_length + ix; + int tid_x = iy * input_length + ix; - if (ix < stride_index & iy < height) { - IndexT idx = p_index[tid]; - p_output[tid_y] = p_input[tid_x - ix + idx]; + if (ix < index_length & iy < batch_size) { + IndexT idx = index[tid]; + out_data[tid] = in_data[tid_x - ix + idx]; } } template -__global__ void index_kernel_grad(const IndexT* p_index, T* p_input, - const T* p_output, size_t stride_index, - size_t stride_input, size_t height) { +__global__ void IndexSampleGradDefault(const IndexT* index, T* in_grad, + const T* out_grad, size_t index_length, + size_t input_length, size_t batch_size) { int ix = blockDim.x * blockIdx.x + threadIdx.x; int iy = blockDim.y * blockIdx.y + threadIdx.y; - int tid = iy * stride_index + ix; - int tid_y = iy * stride_input + ix; - - if (ix < stride_index & iy < height) { - for (int i = 0; i < stride_index; ++i) { - if (ix == i) { - IndexT idx = p_index[tid]; - T tmp1 = p_output[tid]; - T tmp2 = p_input[tid_y - ix + idx]; - p_input[tid_y - ix + idx] = tmp1 + tmp2; - } - } + int tid = iy * index_length + ix; + int tid_y = iy * input_length + ix; + + if (ix < index_length & iy < batch_size) { + IndexT idx = index[tid]; + platform::CudaAtomicAdd(&(in_grad[tid_y - ix + idx]), out_grad[tid]); + } +} + +template +__global__ void IndexSampleGradSpecial(const IndexT* index, T* in_grad, + const T* out_grad, size_t index_length, + size_t input_length, size_t batch_size) { + int ix = blockDim.x * blockIdx.x + threadIdx.x; + int iy = blockDim.y * blockIdx.y + threadIdx.y; + int tid = iy * index_length + ix; + int tid_y = iy * input_length + ix; + + if (ix < index_length & iy < batch_size) { + IndexT idx = index[tid]; + in_grad[tid_y - ix + idx] = out_grad[tid]; } } @@ -84,24 +93,6 @@ class IndexSampleCUDAKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument("It must use CUDAPlace.")); - auto ComputeBlockSize = [](int col) { - if (col > 512) - return 1024; - else if (col > 256) - return 512; - else if (col > 128) - return 256; - else if (col > 64) - return 128; - else if (col > 32) - return 64; - else if (col > 16) - return 32; - else if (col > 8) - return 16; - else - return 8; - }; const auto* in_data = input->data(); auto* out_data 
= output->mutable_data(ctx.GetPlace()); auto stream = @@ -113,9 +104,9 @@ class IndexSampleCUDAKernel : public framework::OpKernel { size_t input_length = input_dim[1]; size_t index_length = index_dim[1]; - auto block_width = ComputeBlockSize(index_length); + auto block_width = platform::RoundToPowerOfTwo(index_length); int block_height = - ComputeBlockSize(index_length * batch_size) / block_width; + platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, @@ -123,12 +114,12 @@ class IndexSampleCUDAKernel : public framework::OpKernel { if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - index_kernel<<>>( + IndexSampleForward<<>>( index_data, in_data, out_data, index_length, input_length, batch_size); } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - index_kernel<<>>( + IndexSampleForward<<>>( index_data, in_data, out_data, index_length, input_length, batch_size); } @@ -163,24 +154,6 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument("It must use CUDAPlace.")); - auto ComputeBlockSize = [](int col) { - if (col > 512) - return 1024; - else if (col > 256) - return 512; - else if (col > 128) - return 256; - else if (col > 64) - return 128; - else if (col > 32) - return 64; - else if (col > 16) - return 32; - else if (col > 8) - return 16; - else - return 8; - }; auto stream = ctx.template device_context().stream(); @@ -191,9 +164,9 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { size_t input_length = input_dim[1]; size_t index_length = index_dim[1]; - auto block_width = ComputeBlockSize(index_length); + auto block_width = platform::RoundToPowerOfTwo(index_length); auto block_height = - ComputeBlockSize(index_length * batch_size) / block_width; + platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, @@ -204,14 +177,26 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - index_kernel_grad<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); + if (index_length == 1) { + IndexSampleGradSpecial<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size); + } else { + IndexSampleGradDefault<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size); + } } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - index_kernel_grad<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); + if (index_length == 1) { + IndexSampleGradSpecial<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size); + } else { + IndexSampleGradDefault<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size); + } } PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index f640c0531192d..c1a8299592a2b 100644 --- 
a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -92,9 +92,9 @@ def config(self): """ For int64 index type """ - self.x_shape = (10, 100) + self.x_shape = (10, 128) self.x_type = "float64" - self.index_shape = (10, 10) + self.index_shape = (10, 64) self.index_type = "int64"

From d931a4a11402ac554f05badb13855fcc1270fa75 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Sun, 31 Jan 2021 05:59:50 +0000 Subject: [PATCH 6/9] Regularizing the code style of index_sample

--- paddle/fluid/operators/index_sample_op.cu | 120 +++++++++------------ 1 file changed, 50 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index c97360345a4dc..196311d97324c 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -14,13 +14,13 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_sample_op.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h"

 namespace paddle { namespace operators {

-using platform::PADDLE_CUDA_NUM_THREADS; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor;

@@ -28,49 +28,41 @@ template __global__ void IndexSampleForward(const IndexT* index, const T* in_data, T* out_data, size_t index_length, size_t input_length, size_t batch_size) { - int ix = blockDim.x * blockIdx.x + threadIdx.x; - int iy = blockDim.y * blockIdx.y + threadIdx.y; - int tid = iy * index_length + ix; - int tid_x = iy * input_length + ix; - - if (ix < index_length & iy < batch_size) { - IndexT idx = index[tid]; - out_data[tid] = in_data[tid_x - ix + idx]; + int index_i = blockDim.x * blockIdx.x + threadIdx.x; + int index_j = blockDim.y * blockIdx.y + threadIdx.y; + int index_idx = index_j * index_length + index_i; + int in_idx = index_j * input_length + index_i; + + if (index_i < index_length & index_j < batch_size) { + IndexT sample_idx = index[index_idx]; + out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; } }

 template -__global__ void IndexSampleGradDefault(const IndexT* index, T* in_grad, - const T* out_grad, size_t index_length, - size_t input_length, size_t batch_size) { - int ix = blockDim.x * blockIdx.x + threadIdx.x; - int iy = blockDim.y * blockIdx.y + threadIdx.y; - int tid = iy * index_length + ix; - int tid_y = iy * input_length + ix; - - if (ix < index_length & iy < batch_size) { - IndexT idx = index[tid]; - platform::CudaAtomicAdd(&(in_grad[tid_y - ix + idx]), out_grad[tid]); - } -} - -template -__global__ void IndexSampleGradSpecial(const IndexT* index, T* in_grad, - const T* out_grad, size_t index_length, - size_t input_length, size_t batch_size) { - int ix = blockDim.x * blockIdx.x + threadIdx.x; - int iy = blockDim.y * blockIdx.y + threadIdx.y; - int tid = iy * index_length + ix; - int tid_y = iy * input_length + ix; - - if (ix < index_length & iy < batch_size) { - IndexT idx = index[tid]; - in_grad[tid_y - ix + idx] = out_grad[tid]; +__global__ void IndexSampleGrad(const IndexT* index, T* in_grad, + const T* out_grad, size_t index_length, + size_t input_length, size_t batch_size, + bool same_data_in_row = true) { + int index_i = blockDim.x * blockIdx.x + threadIdx.x; + int index_j = blockDim.y * blockIdx.y + threadIdx.y; + int index_idx = index_j * index_length + index_i; + int in_idx = index_j * input_length +
index_i; + + if (index_i < index_length & index_j < batch_size) { + IndexT sample_idx = index[index_idx]; + if (same_data_in_row) { + platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]), + out_grad[sample_idx]); + } else { + in_grad[in_idx - index_i + sample_idx] = out_grad[sample_idx]; + } } } -template -class IndexSampleCUDAKernel : public framework::OpKernel { +template +class IndexSampleKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); @@ -127,8 +119,9 @@ class IndexSampleCUDAKernel : public framework::OpKernel { } }; -template -class IndexSampleGradCUDAKernel : public framework::OpKernel { +template +class IndexSampleGradKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* output_grad = ctx.Input(framework::GradVarName("Out")); @@ -156,47 +149,35 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { auto stream = ctx.template device_context().stream(); - auto input_num = input_grad->numel(); auto input_dim = input_grad->dims(); auto index_dim = index->dims(); size_t batch_size = index_dim[0]; size_t input_length = input_dim[1]; size_t index_length = index_dim[1]; + bool same_data_in_index_row = index_length == 1 ? false : true; auto block_width = platform::RoundToPowerOfTwo(index_length); auto block_height = platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; - dim3 block_dim(block_width, block_height); dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); - platform::GpuMemsetAsync(input_grad_data, 0, - sizeof(T) * input_length * batch_size, stream); + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, input_grad, static_cast(0)); if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - if (index_length == 1) { - IndexSampleGradSpecial<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); - } else { - IndexSampleGradDefault<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); - } + IndexSampleGrad<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size, same_data_in_index_row); } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - if (index_length == 1) { - IndexSampleGradSpecial<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); - } else { - IndexSampleGradDefault<<>>( - index_data, input_grad_data, output_grad_data, index_length, - input_length, batch_size); - } + IndexSampleGrad<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size, same_data_in_index_row); } PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } @@ -208,14 +189,13 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( index_sample, - ops::IndexSampleCUDAKernel, - ops::IndexSampleCUDAKernel, - ops::IndexSampleCUDAKernel, - ops::IndexSampleCUDAKernel); + ops::IndexSampleKernel, + ops::IndexSampleKernel, + ops::IndexSampleKernel, + ops::IndexSampleKernel); REGISTER_OP_CUDA_KERNEL( index_sample_grad, - ops::IndexSampleGradCUDAKernel, - ops::IndexSampleGradCUDAKernel, - ops::IndexSampleGradCUDAKernel, - ops::IndexSampleGradCUDAKernel); 
+ ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel, + ops::IndexSampleGradKernel);

From ddd52bd2ad77263a2173dbe8be06693468991082 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Mon, 1 Feb 2021 11:31:17 +0000 Subject: [PATCH 7/9] Deleting the sync code.

--- paddle/fluid/operators/index_sample_op.cu | 1 - 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 196311d97324c..f126fd1c60a86 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -179,7 +179,6 @@ class IndexSampleGradKernel index_data, input_grad_data, output_grad_data, index_length, input_length, batch_size, same_data_in_index_row); } - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } };

From ce289f4550fc0f936c810fb9350c24656ff7a266 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Mon, 1 Feb 2021 11:49:25 +0000 Subject: [PATCH 8/9] Deleting the sync code.

--- paddle/fluid/operators/index_sample_op.cu | 3 --- 1 file changed, 3 deletions(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index f126fd1c60a86..24ea50b22915f 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -143,9 +143,6 @@ class IndexSampleGradKernel framework::proto::VarType::INT32), paddle::framework::DataTypeToString( framework::proto::VarType::INT64))); - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::InvalidArgument("It must use CUDAPlace."));

 auto stream = ctx.template device_context().stream();

From aadcd13e6704d22c678d6d49897b04d5b3b24ff1 Mon Sep 17 00:00:00 2001 From: JamesLim-sy Date: Wed, 3 Feb 2021 05:42:57 +0000 Subject: [PATCH 9/9] Deleting the sync operation and the GPU-place check in the forward kernel.

--- paddle/fluid/operators/index_sample_op.cu | 5 ----- 1 file changed, 5 deletions(-)

diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 24ea50b22915f..c8488eefb984f 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -81,10 +81,6 @@ class IndexSampleKernel framework::proto::VarType::INT32), paddle::framework::DataTypeToString( framework::proto::VarType::INT64))); - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), true, - platform::errors::InvalidArgument("It must use CUDAPlace.")); - const auto* in_data = input->data(); auto* out_data = output->mutable_data(ctx.GetPlace()); auto stream = ctx.template device_context().stream(); @@ -115,7 +111,6 @@ class IndexSampleKernel index_data, in_data, out_data, index_length, input_length, batch_size); } - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); } };
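For readers who want to exercise the forward indexing scheme outside of Paddle, here is a small self-contained smoke test (an illustrative sketch only; the kernel is re-implemented locally with fixed float/int types and a hypothetical name rather than linked against the operator):

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// out[row][col] = in[row][index[row][col]], the forward semantics of
// index_sample, using the same 2-D thread layout as the kernels above.
__global__ void IndexSampleForwardSketch(const int* index, const float* in,
                                         float* out, size_t index_length,
                                         size_t input_length,
                                         size_t batch_size) {
  int col = blockDim.x * blockIdx.x + threadIdx.x;
  int row = blockDim.y * blockIdx.y + threadIdx.y;
  if (col < index_length && row < batch_size) {
    out[row * index_length + col] =
        in[row * input_length + index[row * index_length + col]];
  }
}

int main() {
  const size_t batch_size = 2, input_length = 4, index_length = 3;
  std::vector<float> h_in = {0, 1, 2, 3, 10, 11, 12, 13};
  std::vector<int> h_index = {3, 1, 1, 0, 2, 3};  // column picks per row
  std::vector<float> h_out(batch_size * index_length, 0.f);

  float *d_in = nullptr, *d_out = nullptr;
  int* d_index = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_index, h_index.size() * sizeof(int));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_index, h_index.data(), h_index.size() * sizeof(int),
             cudaMemcpyHostToDevice);

  // Toy launch configuration; the real kernels shape the block via
  // platform::RoundToPowerOfTwo instead of hard-coding it.
  dim3 block(4, 2);
  dim3 grid((index_length + block.x - 1) / block.x,
            (batch_size + block.y - 1) / block.y);
  IndexSampleForwardSketch<<<grid, block>>>(d_index, d_in, d_out, index_length,
                                            input_length, batch_size);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);

  // Expected output: 3 1 1 10 12 13
  for (float v : h_out) printf("%g ", v);
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_index);
  cudaFree(d_out);
  return 0;
}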