-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fused attention op backward and python layer. #36498
Merged
lanxianghit
merged 62 commits into
PaddlePaddle:develop
from
limin2021:fused_attention_bw
Oct 26, 2021
Merged
Add fused attention op backward and python layer. #36498
lanxianghit
merged 62 commits into
PaddlePaddle:develop
from
limin2021:fused_attention_bw
Oct 26, 2021
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
zkh2016
previously approved these changes
Oct 25, 2021
… modify_fused_attention_functional_api_path
…github.com/limin2021/Paddle into fused_attention_bw
zkh2016
previously approved these changes
Oct 26, 2021
xingfeng01
previously approved these changes
Oct 26, 2021
lanxianghit
previously approved these changes
Oct 26, 2021
… fused_attention_bw
limin2021
dismissed stale reviews from lanxianghit, xingfeng01, and zkh2016
via
October 26, 2021 06:12
5f54a0f
xingfeng01
approved these changes
Oct 26, 2021
zkh2016
approved these changes
Oct 26, 2021
lanxianghit
approved these changes
Oct 26, 2021
TCChenlong
approved these changes
Oct 26, 2021
limin2021
added a commit
to limin2021/Paddle
that referenced
this pull request
Oct 26, 2021
功能:本PR的目标是提高attention模块的计算性能。 为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op; 为了减少防存开销,本PR采取了两种优化方法: (1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次; (2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
lanxianghit
pushed a commit
that referenced
this pull request
Oct 27, 2021
ghost
pushed a commit
to piotrekobi/Paddle
that referenced
this pull request
Nov 3, 2021
功能:本PR的目标是提高attention模块的计算性能。 为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op; 为了减少防存开销,本PR采取了两种优化方法: (1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次; (2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR types
New features
PR changes
OPs
Describe
功能:本PR的目标是提高attention模块的计算性能。
为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op;
为了减少防存开销,本PR采取了两种优化方法:
(1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次;
(2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
fused_attention_op 实现的计算逻辑:
fused_attention_op与paddle已有的MultiHeadAttention layer的不同:
(1)计算逻辑范围扩大了,详见上面的伪代码。
(2)q, k, v的weight存储格式不一样。
原有的:保存在三个weight张量中,WQ, WK, WV
本PR:保存在一个weight张量中,qkv_weight
由WQ, WK, WV得到qkv_weight的方法:
实现:
本PR是fused_attention_op 的反向实现,具体细节:
(1)fused_attention_op.cc and fused_attention_op.cu
The C++ impl of backward for fused_attention_op.
Related preceding RRs:
#34883, #35308, #35350 #35621 , #35903, #35905
(2)functional/fused_attention/fused_mult_head_attention():
Add static graph construction method.
(3)test_fused_attention_op.py
Add code to test the correctness of backward of fused_attention_op.
(4)fused_transformer.py/FusedMultiHeadAttention layer:
Add FusedMultiHeadAttention layer.
(5)test_fused_attention_op_api.py
Test the correctness of fused_attention_op python API, both dynamic and static graph.
Unittest results