-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Sparse] add Fused Attention kernel and API for SparseCsrTensor #43966
[Sparse] add Fused Attention kernel and API for SparseCsrTensor #43966
Conversation
2719f19
to
52b61f6
Compare
5f7862e
to
8a0e330
Compare
8a0e330
to
7ef44da
Compare
DenseTensor* dkey, | ||
DenseTensor* dvalue) { | ||
PD_THROW( | ||
"Only support 'fused_attention' CPU backward kernel of SparseTensor now"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPU?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -0,0 +1,146 @@ | |||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2021->2022
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
928ddc3
to
cfbaabb
Compare
cfbaabb
to
849deff
Compare
PR types
New features
PR changes
OPs
Describe
paddle.incubate.sparse.nn.functional.attention(query, key, value, sparse_mask, key_padding_mask, attn_mask)
由于该API调用的
cusparseDnMatSetStridedBatch
、cusparseCsrSetStridedBatch
等NV接口需要在CUDA11.7上才支持,因此CI无法直接运行,本地运行单测结果如下: