
【PaddlePaddle Hackathon 4 No.33】Optimize the GPU performance of the Histogram op for Paddle #53112

Merged 4 commits into PaddlePaddle:develop on Apr 25, 2023

Conversation

zeroRains
Contributor

@zeroRains zeroRains commented Apr 20, 2023

PR types

Performance optimization

PR changes

OPs

Description

Paddle currently uses a hand-written CUDA kernel for the core Histogram computation, but it determines the histogram boundaries with Eigen, which leaves room for optimization.
设计文档:https://github.com/PaddlePaddle/community/blob/master/rfcs/OPs-Perf/20230328_histogram_op_optimization.md

  • Development environment

    1. Device: Tesla V100
    2. Environment: CUDA 11.2, cuDNN 8
  • Optimization method

    • The key change is implementing KernelMinMax as a __global__ kernel, accelerating the part of Histogram that determines the histogram boundaries and thereby improving the op's GPU performance.
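As a rough illustration only, a kernel in this style can be sketched as below. This is a minimal sketch under assumptions, not the PR's actual KernelMinMax: the kernel name, signature, and grid-stride loop are assumptions here; only the phi::CudaAtomicMin/Max helpers are taken from the diff quoted in the review thread.

```cuda
// Sketch (assumed names): each thread folds its grid-stride slice of the
// input into a private min/max, then combines results with one atomic per
// thread. min_out/max_out are assumed pre-initialized to input[0].
template <typename T>
__global__ void KernelMinMaxSketch(const T* input,
                                   int64_t total_elements,
                                   T* min_out,
                                   T* max_out) {
  T local_min = input[0];
  T local_max = input[0];
  // Grid-stride loop over all elements.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < total_elements;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    const T v = input[i];
    local_min = v < local_min ? v : local_min;
    local_max = v > local_max ? v : local_max;
  }
  // Combine per-thread results; the review thread suggests replacing these
  // per-thread global atomics with block/warp-level reductions.
  phi::CudaAtomicMin(min_out, local_min);
  phi::CudaAtomicMax(max_out, local_max);
}
```

Because the computed min/max stay in device memory, the bounds no longer need the device-to-host copy that the Eigen path incurred.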

Performance of the optimized Paddle versus the previous Paddle:

| Case No. | device | input_shape | input_type | bins | min | max | Paddle Perf(ms) | old Paddle Perf(ms) | diff |
|---|---|---|---|---|---|---|---|---|---|
| 1 | Tesla V100 | [16, 64] | int32 | 100 | 0 | 0 | 0.01176 | 0.09403 | faster than 699.57% |
| 2 | Tesla V100 | [16, 64] | int64 | 100 | 0 | 0 | 0.01179 | 0.13624 | faster than 1055.56% |
| 3 | Tesla V100 | [16, 64] | float32 | 100 | 0 | 0 | 0.01117 | 0.01889 | faster than 69.11% |

Performance of the optimized Paddle versus PyTorch:

| Case No. | device | input_shape | input_type | bins | min | max | Paddle Perf(ms) | Pytorch Perf(ms) | diff |
|---|---|---|---|---|---|---|---|---|---|
| 1 | Tesla V100 | [16, 64] | int32 | 100 | 0 | 0 | 0.01176 | 0.02255 | faster than 91.75% |
| 2 | Tesla V100 | [16, 64] | int64 | 100 | 0 | 0 | 0.01179 | 0.03424 | faster than 190.42% |
| 3 | Tesla V100 | [16, 64] | float32 | 100 | 0 | 0 | 0.01117 | 0.02250 | faster than 101.43% |

The optimization improves performance to a different degree in each of the three cases.

@paddle-bot

paddle-bot bot commented Apr 20, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Apr 20, 2023
@zeroRains zeroRains changed the title create KernelMinMax to optimize the performance of histogram op in GPU 【PaddlePaddle Hackathon 4 No.33】为 Paddle 优化 Histogram op 在 GPU 上的计算性能 Apr 20, 2023
    __syncthreads();
    CUDA_KERNEL_LOOP(index, total_elements) {
      const auto input_value = input[index];
      phi::CudaAtomicMin(&min_data, input_value);
Contributor

Using CUDA atomics here is too brute-force; I suggest chaining block-level and warp-level reductions instead, which can lift performance further. It really is too brute-force — when I reviewed the RFC, I assumed the atomics I saw were meant for the histogram-binning part of the computation.

Contributor Author

@zeroRains commented Apr 21, 2023
Done. The change does improve performance, and it also eliminates the original device-to-host (DtoH) transfer overhead. The performance tables have been updated; please take a look and let me know if anything else needs changing.
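For reference, the block/warp-level reduction style suggested above is commonly written with `__shfl_down_sync` plus a shared-memory combine across warps. The following is a minimal sketch under assumptions (float-only, blockDim.x a multiple of warpSize, at most 1024 threads per block), not the PR's final code:

```cuda
// Warp-level min reduction via shuffle; max is symmetric.
__inline__ __device__ float WarpReduceMin(float val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val = min(val, __shfl_down_sync(0xffffffffu, val, offset));
  return val;
}

// Block-level min reduction: each warp reduces privately, lane 0 of each
// warp publishes to shared memory, then the first warp reduces those.
__inline__ __device__ float BlockReduceMin(float val) {
  __shared__ float shared[32];            // one slot per warp (<= 32 warps)
  const int lane = threadIdx.x % warpSize;
  const int wid = threadIdx.x / warpSize;
  val = WarpReduceMin(val);               // reduce within each warp
  if (lane == 0) shared[wid] = val;       // publish per-warp result
  __syncthreads();
  // First warp reduces the per-warp results; FLT_MAX (from <cfloat>) pads
  // the lanes beyond the number of active warps.
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : FLT_MAX;
  if (wid == 0) val = WarpReduceMin(val);
  return val;
}
```

With this shape, only thread 0 of each block issues a single global atomic for the block's result, cutting global atomic traffic from one per element to one per block.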

Contributor

@JamesLim-sy left a comment
LGTM

@JamesLim-sy JamesLim-sy merged commit c1a61fc into PaddlePaddle:develop Apr 25, 2023
@zeroRains zeroRains deleted the histogram branch April 26, 2023 01:13