Skip to content

Commit

Permalink
[WIP]: Very basic grad_function which passes the Ctest, in need of op…
Browse files Browse the repository at this point in the history
…timization desperately.
  • Loading branch information
JamesLim-sy committed Jan 17, 2021
1 parent 866dd28 commit 170e5e1
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,23 @@ __global__ void index_kernel(const IndexT* p_index, const T* p_input,
}

template <typename T, typename IndexT = int>
__global__ void index_kernel_grad(const IndexT_* p_index, T* p_input,
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input,
const T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
extern __shared__ T s_buf[];
int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * stride_index + ix;
int tid_y = iy * stride_input + ix;
s_buf[tid_y] = p_input[tid_y];
s_buf[tid_y] = 0;

if (ix < stride_index & iy < height) {
for (int i = 0; i < stride_index; ++i) {
if (ix == i) {
IndexT idx = p_index[tid];
s_buf[tid_y - ix + idx] += p_output[tid];
T tmp1 = p_output[tid];
T tmp2 = p_input[tid_y - ix + idx];
p_input[tid_y - ix + idx] = tmp1 + tmp2;
}
}
p_input[tid_y] = s_buf[tid_y];
}
}

Expand Down Expand Up @@ -201,17 +199,18 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel<T> {
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);

platform::GpuMemsetAsync(input_grad_data, 0,
sizeof(T) * input_length * batch_size, stream);

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_kernel_grad<
T, int64_t><<<grid_dim, block_dim, input_num * sizeof(T), stream>>>(
index_data, output_grad_data, input_grad_data, index_length,
index_kernel_grad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
index_kernel_grad<
T, int><<<grid_dim, block_dim, input_num * sizeof(T), stream>>>(
index_data, output_grad_data, input_grad_data, index_length,
index_kernel_grad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Expand Down

0 comments on commit 170e5e1

Please sign in to comment.