Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cuda kernel for index_sample. #30380

Merged
merged 10 commits into from
Feb 3, 2021

Conversation

JamesLim-sy
Copy link
Contributor

@JamesLim-sy JamesLim-sy commented Jan 13, 2021

PR types

Performance optimization

PR changes

OPs

Describe

  • 开发环境:
  1. 设备:V100-16G
  2. 环境:CUDA10.1,cuDNN 7
  • 优化方法:
  1. IndexSample OP的反向计算中使用了atomicAdd接口,保证计算时的线程安全性
  2. IndexSample OP的前向Kernel和反向Kernel中,均采用了2维的block和2维Grid,其目的是减少索引计算部分的开销;
  3. 由于Paddle中没有该OP没有GPU Kernel实现,因此主要与pytorch对比OP的性能
  • 优化效果:
No. index_shape input_shape Paddle Perf(ms) Pytorch Perf(ms) diff
1 [5100,1] [5100,38506] 0.7052 1.7032 faster than 58.5 97%
2 [100,64] [100, 128] 0.0055 0.0083 faster than 33.874%
3 [5100,96] [5100,128] 0.0323 0.0377 faster than 14.131%

@CLAassistant
Copy link

CLAassistant commented Jan 13, 2021

CLA assistant check
All committers have signed the CLA.

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

using LoDTensor = framework::LoDTensor;

template <typename T, typename IndexT = int>
__global__ void index_kernel(const IndexT* p_index, const T* p_input,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码需遵循Google C++编程风格,函数命名为AxxBxx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

编码规范确实会修改,下一次PR中这个问题会被处理掉。

template <typename T, typename IndexT = int>
__global__ void index_kernel(const IndexT* p_index, const T* p_input,
T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • p_index -> index, p_input -> input, p_output -> output,感觉这里不需要从命名上特意强调这是个ptr。
  • stride_index、stride_input、height,这几个参数我有点对应不上,变量命名能否更直观一些。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

指针加 "p_" 前缀是长期保持的习惯,后续修改成与paddle贴合的命名规范。

template <typename T, typename IndexT = int>
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input,
const T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从实际含义上来讲:p_index -> index,p_input -> in_grad,p_output -> out_grad

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

paddle/fluid/operators/index_sample_op.cu Outdated Show resolved Hide resolved

dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • cuda并行方案怎么设计的,在PR描述里面补充下。
  • op benchmark里面可能要补充下配置,当前只有1个index_dim=1的配置,最好补充下index_dim>1的配置。
  • 另外看看单测里面有没有index_dim>1的配置

}

template <typename DeviceContext, typename T>
class IndexSampleCUDAKernel : public framework::OpKernel<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议改成特化index_sample.h中IndexSampleKernel类的形式。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个没有改?我是建议改成如下方式:

template <typename T>
class SumKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:

这样L92的检查就可以去掉了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经按要求修改

return 16;
else
return 8;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inline static int RoundToPowerOfTwo(int dim) {
if (dim > 512) {
return 1024;
} else if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}

可使用这个函数代替吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以的,这块的写法非常不美观,肯定替换掉。

…d inevitably increase thread-safety once calcalating the backward step of index_sample OP, and one special CUDA kernel considering the condition that each line of index array only contains 1 element. Besides, thread-deployment in block was 2-demensions.
…d inevitably increase thread-safety once calcalating the backward step of index_sample OP, and one special CUDA kernel considering the condition that each line of index array only contains 1 element. Besides, thread-deployment in block was 2-demensions.
int tid = iy * index_length + ix;
int tid_x = iy * input_length + ix;

if (ix < index_length & iy < batch_size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BlockDim.x最小值为32。当index_length<32时,一个block里面连续的32个线程会有空闲?后续可以再看看有没有更好的并行方案。

namespace paddle {
namespace operators {

using platform::PADDLE_CUDA_NUM_THREADS;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个实际没有用到?

int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * index_length + ix;
int tid_x = iy * input_length + ix;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名可以起的更直观一些,比如你这里的ixiy应该是index数组里面的x和y下标,可以改成index_iindex_jtidindex数组里面的位置,也是out数组里面的位置,可以改成index_idxout_idxtid_x是in数组里面的位置,可以改成in_idx

int ix = blockDim.x * blockIdx.x + threadIdx.x;
int iy = blockDim.y * blockIdx.y + threadIdx.y;
int tid = iy * index_length + ix;
int tid_y = iy * input_length + ix;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量名命名建议,同上。为什么前向的kernel里面叫tid_x,这个kernel里面叫tid_y呢?

paddle/fluid/operators/index_sample_op.cu Show resolved Hide resolved
}

template <typename DeviceContext, typename T>
class IndexSampleCUDAKernel : public framework::OpKernel<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个没有改?我是建议改成如下方式:

template <typename T>
class SumKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:

这样L92的检查就可以去掉了。

(batch_size + block_dim.y - 1) / block_dim.y);

platform::GpuMemsetAsync(input_grad_data, 0,
sizeof(T) * input_length * batch_size, stream);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里改成调用如下函数:

math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, d_x, static_cast<T>(0.0));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@paddle-bot-old
Copy link

Sorry to inform you that fec47c5's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

index_data, in_data, out_data, index_length, input_length,
batch_size);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op实现里面不用同步。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除同步操作

framework::proto::VarType::INT64)));
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个检查可以删掉了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除检查判断

index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op实现里面不用同步。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续已删除

Copy link
Contributor Author

@JamesLim-sy JamesLim-sy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改变量命名方式

return 16;
else
return 8;
};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以的,这块的写法非常不美观,肯定替换掉。

template <typename T, typename IndexT = int>
__global__ void index_kernel_grad(const IndexT* p_index, T* p_input,
const T* p_output, size_t stride_index,
size_t stride_input, size_t height) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

根据建议修改

paddle/fluid/operators/index_sample_op.cu Show resolved Hide resolved
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
}
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续已删除

@JamesLim-sy JamesLim-sy requested a review from Xreki February 3, 2021 03:38
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM and great work~

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants