
FlashAttention support for qkvpacked and varlen #63289

Merged: 18 commits, May 7, 2024

Changes from all commits
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -870,6 +870,18 @@
func : flash_attn_grad
data_type: q

- backward_op : flash_attn_qkvpacked_grad
forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
optional : attn_mask
output : Tensor(qkv_grad)
infer_meta :
func : FlashAttnQKVPackedGradInferMeta
param : [qkv]
kernel :
func : flash_attn_qkvpacked_grad
data_type: qkv

- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
@@ -882,6 +894,18 @@
func : flash_attn_unpadded_grad
data_type: q

- backward_op : flash_attn_varlen_qkvpacked_grad
forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true)
optional : attn_mask
output : Tensor(qkv_grad)
infer_meta :
func : FlashAttnQKVPackedGradInferMeta
param : [qkv]
kernel :
func : flash_attn_varlen_qkvpacked_grad
data_type: qkv

- backward_op : flash_attn_with_sparse_mask_grad
forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
25 changes: 25 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -1049,6 +1049,18 @@
backward : flash_attn_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : flash_attn_qkvpacked
args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset, attn_mask
infer_meta :
func : FlashAttnQKVPackedInferMeta
param : [qkv]
kernel :
func : flash_attn_qkvpacked
data_type : qkv
backward : flash_attn_qkvpacked_grad
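Note on the packed layout this op consumes (illustrative, not from the PR): along the packed axis of the qkv tensor, the first nheads/nheads_k slots hold the query-head groups and the last two hold K and V — the ordering is an assumption suggested by the qkv name and the nheads/nheads_k + 2 dimension. A standalone C++ sketch of that slot mapping, with invented sizes:

// Standalone sketch with made-up sizes (nheads = 8, nheads_k = 2); shows
// which slot of the packed axis would hold Q, K or V under grouped-query
// heads, assuming Q groups come first. Not code from this PR.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t nheads = 8, nheads_k = 2;
  const int64_t group = nheads / nheads_k;  // query heads per KV head
  const int64_t packed_dim = group + 2;     // matches nheads/nheads_k + 2
  for (int64_t slot = 0; slot < packed_dim; ++slot) {
    const char* role = slot < group ? "Q" : (slot == group ? "K" : "V");
    std::cout << "slot " << slot << " -> " << role << "\n";
  }
  return 0;
}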

- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
@@ -1062,6 +1074,19 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad

- op : flash_attn_varlen_qkvpacked
args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
optional : fixed_seed_offset, attn_mask
infer_meta :
func : FlashAttnQKVPackedInferMeta
param : [qkv]
kernel :
func : flash_attn_varlen_qkvpacked
data_type : qkv
intermediate : softmax_lse, seed_offset
backward : flash_attn_varlen_qkvpacked_grad
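For the varlen variant, cu_seqlens_q and cu_seqlens_k follow the usual FlashAttention convention: batch + 1 cumulative token offsets into the packed total_* dimension, starting at 0. A minimal sketch of building them (helper name and lengths invented, not part of this PR):

// Builds cumulative sequence offsets: cu[0] = 0, cu[i+1] = cu[i] + len[i].
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> BuildCuSeqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cu[i + 1] = cu[i] + seqlens[i];
  }
  return cu;
}

int main() {
  // three sequences of lengths 3, 5 and 2 packed into total_q = 10 tokens
  for (int32_t v : BuildCuSeqlens({3, 5, 2})) {
    std::cout << v << " ";  // prints: 0 3 8 10
  }
  std::cout << "\n";
  return 0;
}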

- op : flash_attn_with_sparse_mask
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
}
}

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
if (dqkv) {
dqkv->share_meta(qkv);
}
}
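The gradient of the packed input carries exactly the input's metadata, so share_meta is all this grad infermeta needs. A simplified stand-in (MockMeta is invented; the assumption is that share_meta propagates dims and dtype) of what that buys:

// Simplified stand-in, not phi code: dqkv ends up reporting the same
// dims/dtype as qkv with no extra shape logic.
#include <cassert>
#include <cstdint>
#include <vector>

struct MockMeta {
  std::vector<int64_t> dims;
  int dtype = 0;
  void share_meta(const MockMeta& other) {
    dims = other.dims;
    dtype = other.dtype;
  }
};

int main() {
  MockMeta qkv{{100, 6, 2, 64}, /*dtype id=*/1};
  MockMeta dqkv;
  dqkv.share_meta(qkv);
  assert(dqkv.dims == qkv.dims && dqkv.dtype == qkv.dtype);
  return 0;
}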

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -197,6 +197,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
MetaTensor* dk,
MetaTensor* dv);

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv);

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
28 changes: 28 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -17,7 +17,10 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/common/ddim.h"
#include "paddle/common/errors.h"
#include "paddle/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/box_coder.h"

@@ -390,6 +393,31 @@ void FlashAttnInferMeta(const MetaTensor& q,
seed_offset->set_dtype(phi::DataType::INT64);
}
}

void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
const auto& qkvdims = qkv.dims();
  PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5,
                 phi::errors::InvalidArgument(
                     "qkv dims must be 4 (unpadded) or 5 (padded batch)"));
// qkv [total_*,nheads/nheads_k+2,nheads_k,headdim]
auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]});
if (qkvdims.size() == 5) {
// qkv [batchsize,seqlen,nheads/nheads_k+2,nheads_k,headdim]
out_dims =
DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]};
}
out->set_dims(out_dims);
out->set_dtype(qkv.dtype());
out->set_layout(qkv.layout());
softmax->set_dtype(qkv.dtype());
softmax_lse->set_dtype(qkv.dtype());
if (seed_offset) {
seed_offset->set_dtype(phi::DataType::INT64);
}
}

Comment on lines +415 to +419

Contributor: Can dims be set for these outputs here?

Contributor Author: This part is the same as the existing FlashAttnInferMeta: the shapes of these outputs are resized in FlashAttnFwdParamsV2 in flash_attn_utils.h, so no dims are set here.
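To trace the out_dims arithmetic above with concrete numbers (all sizes invented), the head count is recovered from the packed axis in both the unpadded 4-D branch and the padded 5-D branch:

// Standalone check of the InferMeta shape arithmetic; not code from the PR.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // unpadded 4-D: qkv = [total_tokens, nheads/nheads_k + 2, nheads_k, headdim]
  std::vector<int64_t> qkv4 = {100, 6, 2, 64};
  int64_t heads4 = (qkv4[1] - 2) * qkv4[2];  // (6 - 2) * 2 = 8 query heads
  assert(heads4 == 8);                       // out = [100, 8, 64]

  // padded 5-D: qkv = [batch, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
  std::vector<int64_t> qkv5 = {4, 128, 6, 2, 64};
  int64_t heads5 = (qkv5[2] - 2) * qkv5[3];  // out = [4, 128, 8, 64]
  assert(heads5 == 8);
  return 0;
}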

void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -104,6 +104,12 @@ void FlashAttnInferMeta(const MetaTensor& q,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);

void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);

void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,