
[PaddlePaddle Hackathon 4 No.35] Optimize the performance of the prelu op on GPU for Paddle #51131

Merged: 8 commits merged into PaddlePaddle:develop on Mar 15, 2023

Conversation

thunder95 (Contributor)

PR types

Performance optimization

PR changes

OPs

Describe

Currently, Paddle's PReLU operator is still implemented with an internal loop and does not use any performance-optimization techniques, so there is room for improvement.
Design doc: PaddlePaddle/community#370

  • Development environment:
  1. Device: RTX 2070S
  2. Environment: CUDA 10.2, cuDNN 7
  • Optimization method:
    Use Paddle's internal KPS Elementwise Kernel and IndexKernel for the computation, and optimize the prelu operator with vectorized reads and vectorized writes.

After the optimization, the performance of Paddle compared with the previous Paddle:

| Case No. | input_shape | weight_shape | input_type | paddle Perf (ms) | old_paddle Perf (ms) | diff |
|---|---|---|---|---|---|---|
| 0 | [8, 1024, 3072] | [1L] | float32 | 0.63895 | 0.8584 | faster by 34.35% |
| 1 | [8, 1024, 3072] | [1024L] | float32 | 0.63628 | 1.1135 | faster by 75% |
| 2 | [8, 1024, 3072] | [1L] | float16 | 0.31737 | 0.62442 | faster by 96.75% |
| 3 | [8, 1024, 3072] | [1024L] | float16 | 0.47986 | 0.87672 | faster by 82.7% |

After the optimization, the performance of Paddle compared with PyTorch:

| Case No. | input_shape | weight_shape | input_type | paddle Perf (ms) | pytorch Perf (ms) | diff |
|---|---|---|---|---|---|---|
| 0 | [8, 1024, 3072] | [1L] | float32 | 0.63895 | 0.64366 | faster by 0.737% |
| 1 | [8, 1024, 3072] | [1024L] | float32 | 0.63628 | 0.83144 | faster by 30.67% |
| 2 | [8, 1024, 3072] | [1L] | float16 | 0.31737 | 0.31887 | faster by 0.473% |
| 3 | [8, 1024, 3072] | [1024L] | float16 | 0.47986 | 0.84326 | faster by 75.73% |

Across the four cases, the optimized kernel improves performance to varying degrees.

@paddle-bot commented Mar 2, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

```diff
@@ -43,7 +43,7 @@ __global__ void VectorizedIndexKernel(T *out,
       out + data_offset, &result[0], BLOCK_NUM_X * VecSize);
   }
   size_t num = numel - data_offset;
-  if (num > 0) {
+  if (static_cast<int>(num) > 0) {
```
JamesLim-sy (Contributor):
The static_cast conversion does not seem necessary here.

thunder95 (Contributor, Author):
When I ran the benchmark earlier, it kept failing here and took a long time to track down, which is why I made this change. @JamesLim-sy

JamesLim-sy (Contributor):
I see; in that case keep it as-is, no change needed.

```cpp
size_t channel_num_;
size_t plane_size_;
int numel_;
const T zero = static_cast<T>(0);
```
JamesLim-sy (Contributor):
This does not need to be a member variable; inside the HOSTDEVICE inline PReluChannelFirstWiseCUDAFunctor implementation, the following line is sufficient:

```cpp
constexpr T zero = static_cast<T>(0);
```

thunder95 (Contributor, Author):
@JamesLim-sy Done.

thunder95 (Contributor, Author):
@JamesLim-sy Using constexpr or const causes a compile error, so I have removed it for now.

JamesLim-sy (Contributor) left a comment:
LGTM

@JamesLim-sy JamesLim-sy merged commit 6bd5b7c into PaddlePaddle:develop Mar 15, 2023
Labels: contributor (External developers)