diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index d064f5cbb76a9e..519e6ca1ab943d 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -40,25 +40,23 @@ __global__ void index_kernel(const IndexT* p_index, const T* p_input, } template -__global__ void index_kernel_grad(const IndexT_* p_index, T* p_input, +__global__ void index_kernel_grad(const IndexT* p_index, T* p_input, const T* p_output, size_t stride_index, size_t stride_input, size_t height) { - extern __shared__ T s_buf[]; int ix = blockDim.x * blockIdx.x + threadIdx.x; int iy = blockDim.y * blockIdx.y + threadIdx.y; int tid = iy * stride_index + ix; int tid_y = iy * stride_input + ix; - s_buf[tid_y] = p_input[tid_y]; - s_buf[tid_y] = 0; if (ix < stride_index & iy < height) { for (int i = 0; i < stride_index; ++i) { if (ix == i) { IndexT idx = p_index[tid]; - s_buf[tid_y - ix + idx] += p_output[tid]; + T tmp1 = p_output[tid]; + T tmp2 = p_input[tid_y - ix + idx]; + p_input[tid_y - ix + idx] = tmp1 + tmp2; } } - p_input[tid_y] = s_buf[tid_y]; } } @@ -201,17 +199,18 @@ class IndexSampleGradCUDAKernel : public framework::OpKernel { dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, (batch_size + block_dim.y - 1) / block_dim.y); + platform::GpuMemsetAsync(input_grad_data, 0, + sizeof(T) * input_length * batch_size, stream); + if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); - index_kernel_grad< - T, int64_t><<>>( - index_data, output_grad_data, input_grad_data, index_length, + index_kernel_grad<<>>( + index_data, input_grad_data, output_grad_data, index_length, input_length, batch_size); } else if (index_type == framework::proto::VarType::INT32) { const int* index_data = index->data(); - index_kernel_grad< - T, int><<>>( - index_data, output_grad_data, input_grad_data, index_length, + index_kernel_grad<<>>( + index_data, input_grad_data, output_grad_data, index_length, input_length, batch_size); } PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));