
Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny. #34436

Merged: 1 commit merged into PaddlePaddle:develop from opt_reduce on Aug 2, 2021

Conversation

@ZzSean (Contributor) commented on Jul 28, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny
This fixes the poor performance of the broadcast backward pass when the last dimension being reduced is very small, and also optimizes the implementation of ReduceLastDim, improving performance for some cases. Benchmark results are below, followed by an illustrative sketch of the strategy.

| case               | axis   | PyTorch  | before opt. | after opt. | after opt. vs PyTorch | speedup (before/after) |
|--------------------|--------|----------|-------------|------------|-----------------------|------------------------|
| [4, 2048, 64, 128] | [2, 3] | 305.15us | 328.35us    | 303.54us   | on par (0.53%)        | 1.08                   |
| [16, 2048, 7, 7]   | [2, 3] | 21.738us | 49.712us    | 16.133us   | faster (34.74%)       | 3.08                   |
| [512, 896, 4, 12]  | [3]    | 214.61us | 2.5947ms    | 166.01us   | faster (29.28%)       | 15.63                  |
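As context for the unified strategy, here is a minimal, hypothetical CUDA sketch (not the PR's actual kernel; the name `ReduceLastDimSum` and the fixed sum reducer are illustrative) of the shared idea: a 2-D thread block in which each threadIdx.y row owns one output element while the threads along x cooperatively reduce the last dimension, so a small last dimension still keeps many output rows in flight:

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Butterfly sum within one warp using shuffles.
__device__ float WarpReduceSum(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffffu, val, offset);
  }
  return val;
}

// Each threadIdx.y row of the block reduces one row of x along the last dim.
__global__ void ReduceLastDimSum(const float* x, float* y, int num_rows,
                                 int reduce_len) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= num_rows) return;

  // Thread-local partials: stride over the (possibly very small) last dim.
  float sum = 0.f;
  for (int i = threadIdx.x; i < reduce_len; i += blockDim.x) {
    sum += x[row * reduce_len + i];
  }

  // blockDim.x is exactly one warp here, so a single warp reduce finishes
  // the row; the real kernel generalizes this step with BlockXReduce.
  sum = WarpReduceSum(sum);
  if (threadIdx.x == 0) y[row] = sum;
}

int main() {
  const int num_rows = 1024, reduce_len = 7;  // small last dim, as in the PR
  std::vector<float> h_x(num_rows * reduce_len, 1.f);
  float *d_x = nullptr, *d_y = nullptr;
  cudaMalloc(&d_x, h_x.size() * sizeof(float));
  cudaMalloc(&d_y, num_rows * sizeof(float));
  cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  dim3 block(32, 8);  // one warp per row along x, 8 rows per block along y
  dim3 grid((num_rows + block.y - 1) / block.y);
  ReduceLastDimSum<<<grid, block>>>(d_x, d_y, num_rows, reduce_len);

  std::vector<float> h_y(num_rows);
  cudaMemcpy(h_y.data(), d_y, num_rows * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("y[0] = %.1f (expected %d)\n", h_y[0], reduce_len);
  cudaFree(d_x);
  cudaFree(d_y);
  return 0;
}
```

The payoff on small last dimensions comes from blockDim.y: instead of one output element per block, each block finishes blockDim.y rows at once, which matches cases like [16, 2048, 7, 7] above.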

@paddle-bot-old (bot) commented on Jul 28, 2021

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot-old (bot) commented

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@AnnaTrainingG (Contributor) commented

LGTM

@Xreki changed the title from "Unify the block/grid strategy and implementation of ReduceLastDim and…" to "Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny." on Aug 2, 2021
@Xreki (Contributor) left a comment

LGTM

```diff
@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
 template <typename T, typename ReduceOp>
 static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
-  __shared__ T shared[kWarpSize];
+  __shared__ T shared[2 * kWarpSize];
```
@Xreki (Contributor) commented on Aug 2, 2021
Why 2 * kWarpSize? As a follow-up, we could consider optimizing this function further so that it becomes a basic building block that also fits the use cases of operators such as softmax and batch_norm.
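The thread does not answer the question, but one plausible reading (an assumption, not confirmed by the PR) is that the unified strategy allows 2-D blocks, and each threadIdx.y row needs its own kWarpSize-slot scratch region for per-warp partials; with blockDim.y up to 2, that requires 2 * kWarpSize slots. A minimal sketch of a block-wide x-reduction under that assumption (`BlockXReduceSketch` is hypothetical, not Paddle's implementation; blockDim.x is assumed to be a power-of-two multiple of the warp size):

```cuda
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// Hypothetical sketch: reduce `val` across the x dimension of a 2-D block
// (blockDim.y <= 2). Each y-row reserves its own kWarpSize-slot region of
// the scratch array, which is why it is sized 2 * kWarpSize.
template <typename T, typename ReduceOp>
__device__ T BlockXReduceSketch(T val, ReduceOp reducer) {
  __shared__ T shared[2 * kWarpSize];  // one kWarpSize region per y-row

  int lane = threadIdx.x % kWarpSize;         // lane within the warp
  int warp_in_row = threadIdx.x / kWarpSize;  // warp index along x
  int num_warps = blockDim.x / kWarpSize;     // assumed a power of two, >= 1

  // Stage 1: reduce within each warp via shuffles.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val = reducer(val, __shfl_xor_sync(0xffffffffu, val, offset));
  }

  // Lane 0 of each warp publishes its partial into this row's region.
  if (lane == 0) shared[threadIdx.y * kWarpSize + warp_in_row] = val;
  __syncthreads();

  // Stage 2: the first warp of each row folds the per-warp partials.
  // Lanes >= num_warps hold garbage but never contaminate lane 0, because
  // every shuffle offset used below is smaller than num_warps.
  if (warp_in_row == 0) {
    val = (lane < num_warps) ? shared[threadIdx.y * kWarpSize + lane] : T();
    for (int offset = num_warps / 2; offset > 0; offset >>= 1) {
      val = reducer(val, __shfl_down_sync(0xffffffffu, val, offset));
    }
  }
  return val;  // lane 0 of each row's first warp holds the row result
}
```

Under this indexing scheme, allowing blockDim.y to grow past 2 would require blockDim.y * kWarpSize scratch slots, which may be what the reviewer's question about generalizing the function is probing at.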

@Xreki merged commit c7cc5ac into PaddlePaddle:develop on Aug 2, 2021
@ZzSean deleted the opt_reduce branch on September 3, 2021 02:15