-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a inplace concat custom op based on CUDA VMM API (resubmitted) #9320
base: develop
Are you sure you want to change the base?
Conversation
… loop, avoiding redundant kv cache copy
Thanks for your contribution! |
@@ -0,0 +1,71 @@ | |||
#include "paddle/extension.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
新增的几个文件开头放一下版权声明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加上了
csrc/gpu/vtensor.cu
Outdated
const paddle::Tensor& append_state, | ||
bool transposed_input | ||
) { | ||
// std::cout << "vtensor_reserve_one_token 1 " << (uintptr_t)cache_transposed.data() << std::endl; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
类似注释都可以给删掉
csrc/setup_hip.py
Outdated
"./gpu/pass/remove_assign_out_pass.cc", | ||
"./gpu/pass/apply_vtensor_concat_pass.cc", | ||
"./gpu/vtensor.cu", # TODO: this haven't tested with hip |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件不需要更改,先暂时只在gpu下使用就行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,这几行删掉了
llm/predict/predictor.py
Outdated
if is_paddlenlp_ops_available(): | ||
import paddlenlp_ops | ||
inference_config.enable_custom_passes([ | ||
"remove_assign_out_pass", # remove the assign_out_ op at the end of while loop | ||
"apply_vtensor_concat_pass", # replace concat op with vtensor implementation | ||
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if is_paddlenlp_ops_available(): | |
import paddlenlp_ops | |
inference_config.enable_custom_passes([ | |
"remove_assign_out_pass", # remove the assign_out_ op at the end of while loop | |
"apply_vtensor_concat_pass", # replace concat op with vtensor implementation | |
]) | |
try: | |
import remove_assign_out_pass, apply_vtensor_concat_pass from paddlenlp_ops | |
inference_config.enable_custom_passes([ | |
"remove_assign_out_pass", # remove the assign_out_ op at the end of while loop | |
"apply_vtensor_concat_pass", # replace concat op with vtensor implementation | |
]) | |
except: | |
pass |
这样修改吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddlenlp_ops里没有pass的对象,我换成了新加的算子vtensor_reserve_one_token
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #9320 +/- ##
===========================================
- Coverage 52.92% 52.24% -0.69%
===========================================
Files 661 671 +10
Lines 107069 109655 +2586
===========================================
+ Hits 56670 57288 +618
- Misses 50399 52367 +1968 ☔ View full report in Codecov by Sentry. |
This Pull Request is stale because it has been open for 60 days with no activity. 当前Pull Request 60天内无活动,被标记为stale。 |
PR types
Performance optimization
PR changes
Others
Description
这一PR尝试为当前的大模型推理过程增加基于CUDA VMM API的inplace concat支持(原理类似于vAttention),从而避免在每一个解码步都复制一次整个KV Cache。
该功能暂时只实现了自定义算子,未来还需要增加相关的pass以自动适配其他模型。
目前这一PR在llama模型上应用了这一方案,在3072 input+1024 output的情况下大约有10%的提升。
目前主要的思路是:
vtensor_reserve_one_token自定义算子的语义大致如下:
目前可能存在的问题:
本PR还包括了以下两个PR的内容: