
[Sparse] add Fused Attention kernel and API for SparseCsrTensor #43966

Merged

merged 2 commits into PaddlePaddle:develop from sparse_attention on Jul 5, 2022

Conversation

zhwesky2010 (Contributor) commented Jun 30, 2022

PR types

New features

PR changes

OPs

Describe

paddle.incubate.sparse.nn.functional.attention(query, key, value, sparse_mask, key_padding_mask, attn_mask)

query, key and value are dense tensors of shape [batch_size, num_heads, seq_len, head_dim]; sparse_mask is a SparseCsrTensor whose nonzero layout selects which attention positions are computed; key_padding_mask ([batch_size, seq_len]) and attn_mask ([seq_len, seq_len]) are optional 0/1 masks, as in the example below.

```python
import paddle

batch_size = 16
num_heads = 16
seq_len = 512
head_dim = 32

# Dense query/key/value, shape [batch_size, num_heads, seq_len, head_dim]
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])

query.stop_gradient = False
key.stop_gradient = False
value.stop_gradient = False

# Randomly zero about half of the positions to build a sparse attention
# layout, then convert it to CSR format.
mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size * num_heads, seq_len, seq_len])
sp_mask = mask.to_sparse_csr()

# 0/1 key-padding mask and attention mask
kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float32')
attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float32')

output = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
# kp_mask and attn_mask are optional
output = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask)
output.backward()
```
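For intuition, the fused kernel computes standard scaled dot-product attention, softmax(QKᵀ/√d)·V, restricted to the nonzero layout of sparse_mask. The following dense sketch shows the intended semantics; the helper name, the explicit `pattern` argument, and the mask conventions (0 meaning "masked") are assumptions for illustration, not code from this PR:

```python
import paddle

def dense_attention_reference(query, key, value, pattern, kp_mask=None, attn_mask=None):
    # Hypothetical dense emulation: softmax(Q @ K^T / sqrt(head_dim)) @ V,
    # with disallowed positions pushed to a large negative score before softmax.
    batch_size, num_heads, seq_len, head_dim = query.shape
    scores = paddle.matmul(query, key, transpose_y=True) / (head_dim ** 0.5)
    # pattern: dense 0/1 view of sparse_mask's nonzero layout, reshaped to
    # [batch_size, num_heads, seq_len, seq_len] (assumption).
    scores = scores + (pattern - 1.0) * 1e9
    if attn_mask is not None:
        # attn_mask: [seq_len, seq_len]; assume 0 marks a disallowed position.
        scores = scores + (attn_mask - 1.0) * 1e9
    if kp_mask is not None:
        # kp_mask: [batch_size, seq_len]; assume 0 marks a padded key.
        scores = scores + (kp_mask.reshape([batch_size, 1, 1, seq_len]) - 1.0) * 1e9
    weights = paddle.nn.functional.softmax(scores, axis=-1)
    return paddle.matmul(weights, value)

# Usage with the example above (mask holds 0s and dropout-scaled values, so binarize it):
pattern = (mask != 0).astype('float32').reshape([batch_size, num_heads, seq_len, seq_len])
ref = dense_attention_reference(query, key, value, pattern, kp_mask, attn_mask)
```

Unlike this dense emulation, the fused kernel evaluates scores and the softmax only over the stored CSR columns, so positions outside the layout cost no compute or memory.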

Because the NVIDIA interfaces this API calls, such as cusparseDnMatSetStridedBatch and cusparseCsrSetStridedBatch, are only supported from CUDA 11.7, CI cannot run the test directly; the results of running the unit test locally are as follows:

[screenshot: local unit test results, 2022-07-04 14:43]
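A version-gated skip is the usual way to keep such a test green on older toolkits in CI; a minimal sketch (the helper name and threshold handling are assumptions, not this PR's actual test code):

```python
import unittest
import paddle

def cuda_at_least(major, minor):
    # The fused kernel relies on cusparseDnMatSetStridedBatch /
    # cusparseCsrSetStridedBatch, available only from CUDA 11.7.
    if not paddle.is_compiled_with_cuda():
        return False
    ver = paddle.version.cuda().split('.')  # e.g. "11.7" -> ["11", "7"]
    return (int(ver[0]), int(ver[1])) >= (major, minor)

@unittest.skipIf(not cuda_at_least(11, 7),
                 "fused sparse attention requires CUDA >= 11.7")
class TestSparseAttention(unittest.TestCase):
    def test_forward_backward(self):
        # Run the example above and check output shape / gradients here.
        pass
```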

@zhwesky2010 zhwesky2010 changed the title add fused_attention kernel for SparseTensor [Sparse]add fused_attention kernel for SparseTensor Jun 30, 2022
@zhwesky2010 zhwesky2010 changed the title [Sparse]add fused_attention kernel for SparseTensor [Sparse] add fused_attention API and kernel of SparseCsrTensor Jul 1, 2022
@zhwesky2010 zhwesky2010 changed the title [Sparse] add fused_attention API and kernel of SparseCsrTensor [Sparse] add SparseCsrTensor fused_attention kernel and API Jul 1, 2022
@zhwesky2010 zhwesky2010 force-pushed the sparse_attention branch 3 times, most recently from 5f7862e to 8a0e330 Compare July 4, 2022 07:19
```cpp
    DenseTensor* dkey,
    DenseTensor* dvalue) {
  PD_THROW(
      "Only support 'fused_attention' CPU backward kernel of SparseTensor now");
```
Reviewer (Contributor): GPU?

zhwesky2010 (Author): Done

```
@@ -0,0 +1,146 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
```
Reviewer (Contributor): 2021 -> 2022

zhwesky2010 (Author): Done

@zhwesky2010 zhwesky2010 merged commit 59813de into PaddlePaddle:develop Jul 5, 2022
@zhwesky2010 zhwesky2010 changed the title [Sparse] add SparseCsrTensor fused_attention kernel and API [Sparse] add Fused attention kernel and API for SparseCsrTensor Jul 5, 2022
@zhwesky2010 zhwesky2010 changed the title [Sparse] add Fused attention kernel and API for SparseCsrTensor [Sparse] add Fused Attention kernel and API for SparseCsrTensor Jul 5, 2022