Skip to content

Commit

Permalink
[MODIFIED]: Changing the grad kernel into atomicAdd style, which woul…
Browse files Browse the repository at this point in the history
…d inevitably increase thread-safety once calcalating the backward step of index_sample OP, and one special CUDA kernel considering the condition that each line of index array only contains 1 element. Besides, thread-deployment in block was 2-demensions.
  • Loading branch information
JamesLim-sy committed Jan 22, 2021
1 parent 170e5e1 commit 98a9af7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 74 deletions.
129 changes: 57 additions & 72 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
Expand All @@ -24,39 +25,47 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T, typename IndexT = int>
__global__ void index_kernel(const IndexT* p_index, const T* p_input,
T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
__global__ void IndexSampleForward(const IndexT* index, const T* in_data,
T* out_data, size_t index_length,
size_t input_length, size_t batch_size) {
int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * stride_index + ix;
int tid_x = iy * stride_input + ix;
int tid_y = iy * stride_index + ix;
int tid = iy * index_length + ix;
int tid_x = iy * input_length + ix;

if (ix < stride_index & iy < height) {
IndexT idx = p_index[tid];
p_output[tid_y] = p_input[tid_x - ix + idx];
if (ix < index_length & iy < batch_size) {
IndexT idx = index[tid];
out_data[tid] = in_data[tid_x - ix + idx];
}
}

template <typename T, typename IndexT = int>
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input,
const T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
__global__ void IndexSampleGradDefault(const IndexT* index, T* in_grad,
const T* out_grad, size_t index_length,
size_t input_length, size_t batch_size) {
int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * stride_index + ix;
int tid_y = iy * stride_input + ix;

if (ix < stride_index & iy < height) {
for (int i = 0; i < stride_index; ++i) {
if (ix == i) {
IndexT idx = p_index[tid];
T tmp1 = p_output[tid];
T tmp2 = p_input[tid_y - ix + idx];
p_input[tid_y - ix + idx] = tmp1 + tmp2;
}
}
int tid = iy * index_length + ix;
int tid_y = iy * input_length + ix;

if (ix < index_length & iy < batch_size) {
IndexT idx = index[tid];
platform::CudaAtomicAdd(&(in_grad[tid_y - ix + idx]), out_grad[tid]);
}
}

template <typename T, typename IndexT = int>
__global__ void IndexSampleGradSpecial(const IndexT* index, T* in_grad,
const T* out_grad, size_t index_length,
size_t input_length, size_t batch_size) {
int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * index_length + ix;
int tid_y = iy * input_length + ix;

if (ix < index_length & iy < batch_size) {
IndexT idx = index[tid];
in_grad[tid_y - ix + idx] = out_grad[tid];
}
}

Expand Down Expand Up @@ -84,24 +93,6 @@ class IndexSampleCUDAKernel : public framework::OpKernel<T> {
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));

auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
};
const auto* in_data = input->data<T>();
auto* out_data = output->mutable_data<T>(ctx.GetPlace());
auto stream =
Expand All @@ -113,22 +104,22 @@ class IndexSampleCUDAKernel : public framework::OpKernel<T> {
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];

auto block_width = ComputeBlockSize(index_length);
auto block_width = platform::RoundToPowerOfTwo(index_length);
int block_height =
ComputeBlockSize(index_length * batch_size) / block_width;
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;

dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_kernel<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
index_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
}
Expand Down Expand Up @@ -163,24 +154,6 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel<T> {
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));

auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256)
return 512;
else if (col > 128)
return 256;
else if (col > 64)
return 128;
else if (col > 32)
return 64;
else if (col > 16)
return 32;
else if (col > 8)
return 16;
else
return 8;
};
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();

Expand All @@ -191,9 +164,9 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel<T> {
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];

auto block_width = ComputeBlockSize(index_length);
auto block_width = platform::RoundToPowerOfTwo(index_length);
auto block_height =
ComputeBlockSize(index_length * batch_size) / block_width;
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;

dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
Expand All @@ -204,14 +177,26 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel<T> {

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_kernel_grad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
if (index_length == 1) {
IndexSampleGradSpecial<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
} else {
IndexSampleGradDefault<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
}
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
index_kernel_grad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
if (index_length == 1) {
IndexSampleGradSpecial<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
} else {
IndexSampleGradDefault<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
}
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
}
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ def config(self):
"""
For int64 index type
"""
self.x_shape = (10, 100)
self.x_shape = (10, 128)
self.x_type = "float64"
self.index_shape = (10, 10)
self.index_shape = (10, 64)
self.index_type = "int64"


Expand Down

0 comments on commit 98a9af7

Please sign in to comment.