FlashAttention support qkvpacked and varlen (PaddlePaddle#63289)
* FlashAttention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

* update doc

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
2 people authored and yinfan98 committed May 7, 2024
1 parent b814b7d commit afa9da2
Showing 12 changed files with 1,580 additions and 74 deletions.
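
The diff below registers two new packed-QKV ops (flash_attn_qkvpacked and flash_attn_varlen_qkvpacked) and their shape inference; the commit also updates the Python wrappers in python/paddle/nn/functional/flash_attention.py. For orientation, here is a minimal sketch of how the dense packed variant might be called from Python. The function name comes from this commit, but the exact keyword names and the (out, softmax) return structure are assumptions, not confirmed by this excerpt.

# Hedged sketch: keyword names and return structure are assumed.
import paddle
from paddle.nn.functional.flash_attention import flash_attn_qkvpacked

batch, seqlen, nheads, nheads_k, headdim = 2, 128, 16, 4, 64
# Padded packed layout (see FlashAttnQKVPackedInferMeta below):
# [batch, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
# The first nheads/nheads_k slices along dim 2 hold the grouped Q heads;
# the last two hold K and V.
qkv = paddle.randn(
    [batch, seqlen, nheads // nheads_k + 2, nheads_k, headdim],
    dtype='float16',
)
out, softmax = flash_attn_qkvpacked(qkv, dropout=0.0, causal=True)
print(out.shape)  # [2, 128, 16, 64] == [batch, seqlen, nheads, headdim]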
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -882,6 +882,18 @@
    func : flash_attn_grad
    data_type: q

- backward_op : flash_attn_qkvpacked_grad
  forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
  optional : attn_mask
  output : Tensor(qkv_grad)
  infer_meta :
    func : FlashAttnQKVPackedGradInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_qkvpacked_grad
    data_type: qkv

- backward_op : flash_attn_unpadded_grad
  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
@@ -894,6 +906,18 @@
    func : flash_attn_unpadded_grad
    data_type: q

- backward_op : flash_attn_varlen_qkvpacked_grad
  forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true)
  optional : attn_mask
  output : Tensor(qkv_grad)
  infer_meta :
    func : FlashAttnQKVPackedGradInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_varlen_qkvpacked_grad
    data_type: qkv

- backward_op : flash_attn_with_sparse_mask_grad
  forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
25 changes: 25 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -1109,6 +1109,18 @@
  backward : flash_attn_grad
  interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : flash_attn_qkvpacked
  args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnQKVPackedInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_qkvpacked
    data_type : qkv
  backward : flash_attn_qkvpacked_grad

- op : flash_attn_unpadded
  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
@@ -1122,6 +1134,19 @@
  intermediate : softmax_lse, seed_offset
  backward : flash_attn_unpadded_grad

- op : flash_attn_varlen_qkvpacked
  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnQKVPackedInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_varlen_qkvpacked
    data_type : qkv
  intermediate : softmax_lse, seed_offset
  backward : flash_attn_varlen_qkvpacked_grad

- op : flash_attn_with_sparse_mask
  args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
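
The varlen op above takes a token-packed batch plus cumulative sequence lengths (cu_seqlens), so sequences of different lengths share one tensor without padding. A companion sketch under the same caveats as the first one (wrapper name from the op definition; keyword names are guesses):

import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attn_varlen_qkvpacked

seqlens = [37, 91]    # two variable-length sequences
total = sum(seqlens)  # 128 tokens packed end to end
nheads, nheads_k, headdim = 16, 4, 64

# Unpadded packed layout: [total, nheads/nheads_k + 2, nheads_k, headdim]
qkv = paddle.randn(
    [total, nheads // nheads_k + 2, nheads_k, headdim], dtype='float16'
)
# cu_seqlens is a prefix sum with a leading zero: [0, 37, 128]
cu_seqlens = paddle.to_tensor(np.cumsum([0] + seqlens, dtype=np.int32))

out, softmax = flash_attn_varlen_qkvpacked(
    qkv,
    cu_seqlens,  # cu_seqlens_q
    cu_seqlens,  # cu_seqlens_k (self-attention)
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    scale=1.0 / (headdim ** 0.5),
    causal=True,
)
print(out.shape)  # [128, 16, 64]; output follows the input layout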
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
  }
}

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
  if (dqkv) {
    dqkv->share_meta(qkv);
  }
}

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                  const MetaTensor& out_grad,
                                  MetaTensor* x_grad,
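
FlashAttnQKVPackedGradInferMeta (above) simply shares the input's meta with the gradient, so the backward pass returns a gradient with exactly the packed qkv shape. A short illustration, reusing the hypothetical wrapper from the first sketch:

# Assumes the flash_attn_qkvpacked wrapper sketched earlier.
qkv = paddle.randn([2, 128, 6, 4, 64], dtype='float16')
qkv.stop_gradient = False
out, _ = flash_attn_qkvpacked(qkv, causal=True)
out.sum().backward()
# dqkv->share_meta(qkv): the gradient mirrors the packed input exactly.
assert qkv.grad.shape == qkv.shape  # [2, 128, 6, 4, 64]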
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -207,6 +207,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
                            MetaTensor* dk,
                            MetaTensor* dv);

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv);

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                  const MetaTensor& out_grad,
                                  MetaTensor* x_grad,
28 changes: 28 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -17,7 +17,10 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/common/ddim.h"
#include "paddle/common/errors.h"
#include "paddle/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/box_coder.h"

@@ -433,6 +436,31 @@ void FlashAttnInferMeta(const MetaTensor& q,
    seed_offset->set_dims({2});
  }
}

void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
                                 MetaTensor* out,
                                 MetaTensor* softmax,
                                 MetaTensor* softmax_lse,
                                 MetaTensor* seed_offset) {
  const auto& qkvdims = qkv.dims();
  PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5,
                 phi::errors::InvalidArgument(
                     "qkv dims must be 4 (unpadded) or 5 (padded batch)"));
  // qkv [total_*, nheads/nheads_k + 2, nheads_k, headdim]
  auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]});
  if (qkvdims.size() == 5) {
    // qkv [batchsize, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
    out_dims =
        DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]};
  }
  out->set_dims(out_dims);
  out->set_dtype(qkv.dtype());
  out->set_layout(qkv.layout());
  softmax->set_dtype(qkv.dtype());
  softmax_lse->set_dtype(qkv.dtype());
  if (seed_offset) {
    seed_offset->set_dtype(phi::DataType::INT64);
  }
}

void ArangeTensorInferMeta(const MetaTensor& start,
                           const MetaTensor& end,
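
The output-shape rule in FlashAttnQKVPackedInferMeta is compact but dense. Restated in Python purely as a reading aid (this mirrors the C++ above; it is not part of the change):

def qkvpacked_out_dims(qkv_dims):
    # Mirror of FlashAttnQKVPackedInferMeta's output-shape rule.
    if len(qkv_dims) == 4:
        # unpadded: [total, nheads/nheads_k + 2, nheads_k, headdim]
        total, g_plus_2, nheads_k, headdim = qkv_dims
        return [total, (g_plus_2 - 2) * nheads_k, headdim]
    if len(qkv_dims) == 5:
        # padded: [batch, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
        batch, seqlen, g_plus_2, nheads_k, headdim = qkv_dims
        return [batch, seqlen, (g_plus_2 - 2) * nheads_k, headdim]
    raise ValueError("qkv dims must be 4 (unpadded) or 5 (padded batch)")

# (nheads/nheads_k + 2 - 2) * nheads_k == nheads, so out keeps nheads heads:
assert qkvpacked_out_dims([128, 6, 4, 64]) == [128, 16, 64]
assert qkvpacked_out_dims([2, 128, 6, 4, 64]) == [2, 128, 16, 64]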
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -115,6 +115,12 @@ void FlashAttnInferMeta(const MetaTensor& q,
                        MetaTensor* softmax_lse,
                        MetaTensor* seed_offset);

void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
                                 MetaTensor* out,
                                 MetaTensor* softmax,
                                 MetaTensor* softmax_lse,
                                 MetaTensor* seed_offset);

void InstanceNormInferMeta(const MetaTensor& x,
                           const MetaTensor& scale,
                           const MetaTensor& bias,