diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 38a243d6eefbe2..3a8d4d2d2862c9 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -882,6 +882,18 @@ func : flash_attn_grad data_type: q +- backward_op : flash_attn_qkvpacked_grad + forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked_grad + data_type: qkv + - backward_op : flash_attn_unpadded_grad forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) @@ -894,6 +906,18 @@ func : flash_attn_unpadded_grad data_type: q +- backward_op : flash_attn_varlen_qkvpacked_grad + forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked_grad + data_type: qkv + - backward_op : flash_attn_with_sparse_mask_grad forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5c9a314435476d..a59f50b7a8ac64 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1109,6 +1109,18 @@ backward : flash_attn_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : flash_attn_qkvpacked + 
args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset, attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked + data_type : qkv + backward : flash_attn_qkvpacked_grad + - op : flash_attn_unpadded args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -1122,6 +1134,19 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_unpadded_grad +- op : flash_attn_varlen_qkvpacked + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset , attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked + data_type : qkv + intermediate : softmax_lse, seed_offset + backward : flash_attn_varlen_qkvpacked_grad + - op : flash_attn_with_sparse_mask args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 9b8e94fc1380ce..c7574910504cd7 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q, } } +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) { + if (dqkv) { + dqkv->share_meta(qkv); + } +} + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 789419a54fde6f..63912e98d50f3b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -207,6 +207,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q, MetaTensor* dk, MetaTensor* dv); +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq); + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index fc43c09105bd60..f766d1e4dd9da9 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -17,7 +17,10 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "paddle/common/ddim.h" +#include "paddle/common/errors.h" #include "paddle/common/layout.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/impl/box_coder.h" @@ -433,6 +436,31 @@ void FlashAttnInferMeta(const MetaTensor& q, seed_offset->set_dims({2}); } } +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset) { + const auto& qkvdims = qkv.dims(); + PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5, + phi::errors::InvalidArgument( + "qkv dims must be 4(unpadded) or 5(padded batch)")); + // qkv [total_*,nheads/nheads_k+2,nheads_k,headdim] + auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]}); + if (qkvdims.size() == 5) { + // qkv [batchsize,seqlen,nheads/nheads_k+2,nheads_k,headdim] + out_dims = + DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]}; + } + out->set_dims(out_dims); + out->set_dtype(qkv.dtype()); + out->set_layout(qkv.layout()); + softmax->set_dtype(qkv.dtype()); + softmax_lse->set_dtype(qkv.dtype()); + if (seed_offset) { + seed_offset->set_dtype(phi::DataType::INT64); + } +} void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& end, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index a4e429fdd277d6..e1bc6615abfa43 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -115,6 +115,12 @@ void FlashAttnInferMeta(const MetaTensor& q, MetaTensor* softmax_lse, MetaTensor* seed_offset); +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset); + void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4f93288edaf14c..1e919c122bf033 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -13,12 +13,16 @@ // limitations under the License. #include "paddle/phi/kernels/flash_attn_grad_kernel.h" +#include #include "glog/logging.h" // For VLOG() +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" @@ -31,26 +35,205 @@ int get_num_split() { return FLAGS_cudnn_deterministic ? 
1 : 0; } +template +static __global__ void SumStridedKV(const T* src, + T* dst, + const uint64_t sRowDim1, + const uint64_t sRowDim2, + const uint64_t sRowDim3, + const uint64_t sColDim, + const uint64_t sRowStride1, + const uint64_t sRowStride2, + const uint64_t sColStride, + const uint64_t dRowStride1, + const uint64_t dRowStride2) { + // SrcShape [seqlen, num_heads_k, num_heads/num_heads_k, headdim] + // AxisName [row1 , row2 , col , row3 ] + // LoopMap [blockx, thready , serialreduce , threadx] + // Ensure blockDim.x == 32 && blockDim.z == 1 + // Ensure sRowStride3 == dRowStride3 == 1 (headdim dim is contiguous) + using IndexType = uint64_t; + constexpr IndexType BlockDimX = 32; + const IndexType SRow1Begin = blockIdx.x * sRowStride1; + const IndexType SRow1End = sRowDim1 * sRowStride1; + const IndexType SRow1Stride = gridDim.x * sRowStride1; + + const IndexType SRow2Begin = threadIdx.y * sRowStride2; + const IndexType SRow2End = sRowDim2 * sRowStride2; + const IndexType SRow2Stride = blockDim.y * sRowStride2; + + // const IndexType SRow3Begin = threadIdx.x * sRowStride3; + // const IndexType SRow3End = sRowDim3 * sRowStride3; + // const IndexType SRow3Stride = BlockDimX * sRowStride3; + + constexpr IndexType SColBegin = 0; + const IndexType SColEnd = sColDim * sColStride; + const IndexType SColStride = sColStride; + + const IndexType DRow1Begin = blockIdx.x * dRowStride1; + const IndexType DRow1Stride = gridDim.x * dRowStride1; + + const IndexType DRow2Begin = threadIdx.y * dRowStride2; + const IndexType DRow2Stride = dRowStride2; + + // const IndexType DRow3Begin = threadIdx.x * dRowStride3; + // const IndexType DRow3Stride = blockDim.x * dRowStride3; + + for (auto row1 = SRow1Begin, drow1 = DRow1Begin; row1 < SRow1End; + row1 += SRow1Stride, drow1 += DRow1Stride) { + for (auto row2 = SRow2Begin, drow2 = DRow2Begin; row2 < SRow2End; + row2 += SRow2Stride, drow2 += DRow2Stride) { + const auto i1 = row1 + row2 + threadIdx.x; + const auto di1 = drow1 + drow2 + threadIdx.x; + T v[HeaddimDiv32]; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + v[i] = T{0}; + } + for (auto col = SColBegin; col < SColEnd; col += SColStride) { + const auto i2 = i1 + col; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + v[i] += src[i2 + i * BlockDimX]; + } + } +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + dst[di1 + i * BlockDimX] = v[i]; + } + } + } +} + +template +static auto selectSumkernel(int64_t headdim) { + PADDLE_ENFORCE_LE(headdim, + 256, + phi::errors::InvalidArgument( + "FlashAttention only support headdim <= 256")); + PADDLE_ENFORCE_EQ(headdim % 32, + 0, + phi::errors::InvalidArgument( + "FlashAttention only support headdim %% 32 == 0")); + PADDLE_ENFORCE_NE( + headdim, 0, phi::errors::InvalidArgument("Headdim can't be zero")); +#define CASEN(n) \ + case n: \ + return SumStridedKV; + switch (headdim / 32) { + CASEN(1); + CASEN(2); + CASEN(3); + CASEN(4); + CASEN(5); + CASEN(6); + CASEN(7); + CASEN(8); + } + PADDLE_FATAL("Unreachable in selectSumKernel"); +#undef CASEN +} + template -void FlashAttnUnpaddedGradKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const DenseTensor& cu_seqlens_q, - const DenseTensor& cu_seqlens_k, - const DenseTensor& out, - const DenseTensor& softmax_lse, - const DenseTensor& seed_offset, - const paddle::optional& attn_mask, - const DenseTensor& dout, - int64_t max_seqlen_q, - int64_t max_seqlen_k, - float scale, - float dropout, - bool causal, - 
DenseTensor* dq, - DenseTensor* dk, - DenseTensor* dv) { +static void kvReduceForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[2], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + const int64_t reduceDimSize = dk_tmp.dims()[2]; + const size_t blockNum = + std::min((static_cast(dk_tmp.dims()[0] + 31) / 32), + static_cast(1024l)); + const dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[3]); + sumkernel<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0], + dk_tmp.dims()[1], + dk_tmp.dims()[3], + dk_tmp.dims()[2], + dk_tmp.strides()[0], + dk_tmp.strides()[1], + // dk_tmp.strides()[3], + dk_tmp.strides()[2], + dk->strides()[0], + dk->strides()[1] + // dk->strides()[2] + ); +} +template +static void kvReduceBatchedForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[4], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk->strides()[0], + dk->strides()[1] * dk->dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[0], + dk_tmp.strides()[1] * dk_tmp.dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); + const int64_t reduceDimSize = dk_tmp.dims()[3]; + const size_t blockNum = std::min( + (static_cast(dk_tmp.dims()[0] * dk_tmp.dims()[1] + 31) / 32), + static_cast(1024l)); + const dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[4]); + // here implicitly flat [batch,seqlen], and require batch dim to be contiguous + sumkernel<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0] * dk_tmp.dims()[1], + dk_tmp.dims()[2], + dk_tmp.dims()[4], + dk_tmp.dims()[3], + dk_tmp.strides()[1], + dk_tmp.strides()[2], + // dk_tmp.strides()[4], + dk_tmp.strides()[3], + dk->strides()[1], + dk->strides()[2] + // dk->strides()[3] + ); +} + +template +void FlashAttnUnpaddedGradBaseKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv, + bool varlen_padded) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [total_*, num_heads, head_dim] auto dims = q.dims(); @@ -64,37 +247,30 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } std::initializer_list dk_dv_shape = { total_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - if (dk && 
is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -139,9 +315,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, cu_seqlens_q.data(), cu_seqlens_k.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -162,20 +338,209 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out.strides()[0], + out.strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out.strides()[0], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + kdq->strides()[1], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[0], + dout.strides()[1], + max_seqlen_q * kdq->strides()[0], + max_seqlen_k * kdk->strides()[0], + max_seqlen_k * kdv->strides()[0], + max_seqlen_q * dout.strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + if (dk->meta().is_contiguous()) + phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + else + kvReduceForGQA(ctx, dk_tmp, dk); } if (dv) { - phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + if (dv->meta().is_contiguous()) + phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + else + kvReduceForGQA(ctx, dv_tmp, dv); } } #else RaiseNotSupportedError(); #endif } + +template +void FlashAttnUnpaddedGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + dq, + dk, + dv, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor* out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = 
strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + *out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out->dtype())); +} +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; +template +void FlashAttnVarlenQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool varlen_padded, + DenseTensor* dqkv) { +#ifdef PADDLE_WITH_FLASHATTN + // q,k,v [total_*, num_heads, head_dim] + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + { + std::vector inputs{}; + std::vector outputs{dqkv}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, &dq, 1, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 1, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + &dq, + &dk, + &dv, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} template void FlashAttnGradBaseKernel( const Context& ctx, @@ -208,36 +573,29 @@ void FlashAttnGradBaseKernel( bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + std::initializer_list dk_dv_shape = { + batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - std::initializer_list dk_dv_shape = { - batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; - if (dk && is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -291,9 +649,9 @@ void FlashAttnGradBaseKernel( params.softmax_d.data(), 
softmax_lse.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -321,14 +679,45 @@ void FlashAttnGradBaseKernel( params.attn_mask_start_row_indices_tensor ? params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out.strides()[1], + out.strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out.strides()[0], + kdq->strides()[1], + kdk->strides()[1], + kdv->strides()[1], + kdq->strides()[2], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[1], + dout.strides()[2], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + dout.strides()[0]); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + if (dk->meta().is_contiguous()) + phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + else + kvReduceBatchedForGQA(ctx, dk_tmp, dk); } + if (dv) { - phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + if (dv->meta().is_contiguous()) + phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + else + kvReduceBatchedForGQA(ctx, dv_tmp, dv); } } #else @@ -351,6 +740,15 @@ void FlashAttnGradKernel(const Context& ctx, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -369,6 +767,58 @@ void FlashAttnGradKernel(const Context& ctx, dv); } +template +void FlashAttnQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dqkv) { +#ifdef PADDLE_WITH_FLASHATTN + // qkv [batchsize, seqlen, nheads/nheads_k+2, nheads_k, head_dim] + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. 
No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, &dq, 2, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 2, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 2, head_groupnum - 1, 1); + FlashAttnGradBaseKernel(ctx, + q, + k, + v, + out, + softmax_lse, + seed_offset, + attn_mask, + paddle::none, + dout, + dropout, + causal, + 0, + &dq, + &dk, + &dv); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseGradKernel( const Context& ctx, @@ -386,6 +836,15 @@ void FlashAttnWithSparseGradKernel( DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -414,6 +873,15 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad, kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_grad, GPU, ALL_LAYOUT, @@ -423,6 +891,15 @@ PD_REGISTER_KERNEL(flash_attn_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 7eb2d342feb792..64eb8450bcac62 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -14,17 +14,26 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" +#include #include "glog/logging.h" // For VLOG() #include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" namespace phi { +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; template -void FlashAttnUnpaddedKernel( +void FlashAttnUnpaddedBaseKernel( const Context& ctx, const DenseTensor& q, const DenseTensor& k, @@ -44,10 +53,16 @@ void FlashAttnUnpaddedKernel( DenseTensor* out, DenseTensor* softmax, DenseTensor* softmax_lse, - DenseTensor* seed_offset) { + DenseTensor* seed_offset, + bool varlen_padded) { #ifdef PADDLE_WITH_FLASHATTN ctx.template Alloc(out); + if (varlen_padded) { + std::vector inputs{}; + std::vector outputs{out}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } cudaStream_t stream = ctx.stream(); // q, k, v [total_q/k/v, num_heads, head_dim] @@ -120,13 +135,158 @@ void FlashAttnUnpaddedKernel( params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? 
params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out->strides()[0], + out->strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out->strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); #endif } +template +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor* out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + *out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out->dtype())); +} +template +void FlashAttnVarlenQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + bool varlen_padded, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnBaseKernel( const Context& ctx, @@ -239,7 +399,19 @@ void FlashAttnBaseKernel( params.attn_mask_start_row_indices_tensor ? 
params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out->strides()[1], + out->strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out->strides()[0]); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); @@ -281,6 +453,49 @@ void FlashAttnKernel(const Context& ctx, seed_offset); } +template +void FlashAttnQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); + FlashAttnBaseKernel(ctx, + q, + k, + v, + fixed_seed_offset, + attn_mask, + paddle::none, + dropout, + causal, + return_softmax, + is_test, + rng_name, + 0, + out, + softmax, + softmax_lse, + seed_offset); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseMaskKernel( const Context& ctx, @@ -330,6 +545,16 @@ PD_REGISTER_KERNEL(flash_attn_unpadded, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn, GPU, ALL_LAYOUT, @@ -340,6 +565,16 @@ PD_REGISTER_KERNEL(flash_attn, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(1).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, GPU, ALL_LAYOUT, diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a9290887533764..7722ffb437389b 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -87,6 +87,8 @@ ) from .flash_attention import ( flash_attention_with_sparse_mask, + flash_attn_qkvpacked, + flash_attn_varlen_qkvpacked, scaled_dot_product_attention, sdp_kernel, # noqa: F401 ) @@ -279,5 +281,7 @@ 'gaussian_nll_loss', 'scaled_dot_product_attention', 'flash_attention_with_sparse_mask', + 'flash_attn_qkvpacked', + 'flash_attn_varlen_qkvpacked', 'group_norm', ] diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index e82684c32981de..84c7882a7151da 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -300,6 +300,158 @@ def flash_attention( ) +def flash_attn_qkvpacked( + qkv, + dropout=0.0, + causal=False, + return_softmax=False, + *, + fixed_seed_offset=None, + rng_name="", + training=True, + name=None, +): + r""" + The equation is: + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. 
+ The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + Don't call this API if flash_attn is not supported. + + Args: + qkv(Tensor): The query/key/value packed tensor in the Attention module. + 5-D tensor with shape: + [batchsize, seqlen , num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. + dropout(float): The dropout ratio. + causal(bool): Whether enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + training(bool): Whether it is in the training phase. + rng_name(str): The name to select Generator. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + - out(Tensor). The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. + + Examples: + .. code-block:: python + + >>> # doctest: +SKIP('flash_attn need A100 compile') + >>> import paddle + + >>> paddle.seed(2023) + >>> q = paddle.rand((1, 128, 2, 16)) + >>> qkv = paddle.stack([q, q, q], axis=2) + >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False) + >>> print(output) + (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, + 0.42132431, 0.39157745], + [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286, + 0.76690865, 0.71485823]], + ..., + [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247, + 0.53336465, 0.54540104], + [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, + 0.40526697, 0.60541755]]]]), None) + >>> # doctest: -SKIP + + """ + head_dim = qkv.shape[-1] + sdp_func_name = _select_sdp(head_dim) + + if sdp_func_name == "flash_attn": + if in_dynamic_or_pir_mode(): + ( + result_attention, + result_softmax, + _, + _, + ) = _C_ops.flash_attn_qkvpacked( + qkv, + fixed_seed_offset, + None, + dropout, + causal, + return_softmax, + not training, + rng_name, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + else: + # don't call qkvpacked if not using flash_attn + query = qkv[:, :, :-2].reshape([0, 0, -1, qkv.shape[-1]]) + key = qkv[:, :, -2] + value = qkv[:, :, -1] + if sdp_func_name == "mem_efficient": + from paddle.incubate.nn.memory_efficient_attention import ( + memory_efficient_attention, + ) + + 
output = memory_efficient_attention( + query, + key, + value, + attn_bias=None, + p=dropout, + scale=None, + training=training, + ) + return output, None + else: + return _math_attention( + query, + key, + value, + dropout_rate=dropout, + causal=causal, + return_softmax=return_softmax, + training=training, + ) + + def flash_attn_unpadded( query, key, @@ -439,6 +591,134 @@ def flash_attn_unpadded( return out, softmax if return_softmax else None +def flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + scale, + dropout=0.0, + causal=False, + return_softmax=False, + fixed_seed_offset=None, + rng_name="", + varlen_padded=True, + training=True, + name=None, +): + r""" + The equation is: + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Args: + qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed + 4-D tensor with shape: + [total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. + cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index query. + cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index key and value. + max_seqlen_q(int): Maximum sequence length of query in the batch. Note it's the padding length, not the max actual seqlen + max_seqlen_k(int): Maximum sequence length of key/value in the batch. + scale(float): The scaling of QK^T before applying softmax. + dropout(float): The dropout ratio. + causal(bool): Whether enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + rng_name(str): The name to select Generator. + training(bool): Whether it is in the training phase. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + - out(Tensor). The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. + + Examples: + .. 
code-block:: python + + >>> # doctest: +SKIP('flash_attn need A100 compile') + >>> import paddle + >>> paddle.seed(2023) + >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') + >>> cu = paddle.arange(0, 384, 128, dtype='int32') + >>> qq = paddle.reshape(q, [256, 8, 16]) + >>> qkv = paddle.stack([qq, qq, qq], axis=2) + >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) + >>> # doctest: -SKIP + + """ + if in_dynamic_mode(): + ( + result_attention, + result_softmax, + ) = _C_ops.flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + None, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + not training, + rng_name, + varlen_padded, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_varlen_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_k': cu_seqlens_k, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_varlen_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_k': max_seqlen_k, + 'scale': scale, + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + + def scaled_dot_product_attention( query, key, diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index a5eba6148db816..c410b743a78379 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -26,7 +26,9 @@ from paddle.nn.functional.flash_attention import ( flash_attention, flash_attention_with_sparse_mask, + flash_attn_qkvpacked, flash_attn_unpadded, + flash_attn_varlen_qkvpacked, scaled_dot_product_attention, ) from paddle.pir_utils import test_with_pir_api @@ -956,5 +958,422 @@ def setUp(self): self.causal = True +class TestFlashAttentionVarlenQKVPackedGQA(TestFlashAttentionGQA): + def gen_unpadded_data(self, dtype): + seq_len_q = np.random.randint( + low=1, high=self.seq_len, size=[self.batch_size] + ) + seq_len_k = seq_len_q + cu_seqlen_q = paddle.to_tensor( + [0] + np.cumsum(seq_len_q).tolist(), dtype=paddle.int32 + ) + cu_seqlen_k = cu_seqlen_q + + qs, ks, vs = [], [], [] + for i in range(self.batch_size): + tmp_q = ( + paddle.randn( + [seq_len_q[i] * self.num_head * self.head_dim], dtype=dtype + ) + / 1e2 + ) + tmp_k = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + tmp_v = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + qs.append(tmp_q) + ks.append(tmp_k) + vs.append(tmp_v) + + q = paddle.concat(qs, axis=0).reshape( + [-1, self.num_head, self.head_dim] + ) + k = paddle.concat(ks, axis=0).reshape( + [-1, self.num_head // self.num_group, self.head_dim] + ) + v = paddle.concat(vs, axis=0).reshape( + [-1, 
self.num_head // self.num_group, self.head_dim] + ) + return q, k, v, cu_seqlen_q, cu_seqlen_k + + def calc_qkvpackedfa( + self, q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad, causal, varlen_padded + ): + q, k, v = self.clone_tensor([q, k, v]) + scale = self.head_dim ** (-0.5) + if varlen_padded: + tq = q.reshape( + [ + self.batch_size * self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tk = k.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tv = v.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([tk, tv], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out_grad = out_grad.reshape(out[0].shape) + else: + tq = q.reshape( + [ + 0, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([k, v], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out = out[0] + grads = paddle.grad(outputs=out, inputs=qkv, grad_outputs=out_grad) + qkvgrad = grads[0] + out = out.reshape(q.shape) + qgrad = qkvgrad[:, :-2].reshape(q.shape) + kgrad = qkvgrad[:, -2].reshape(k.shape) + vgrad = qkvgrad[:, -1].reshape(v.shape) + if varlen_padded: + out = self.unpad(out, cu_seqlen_q) + qgrad = self.unpad(qgrad, cu_seqlen_q) + kgrad = self.unpad(kgrad, cu_seqlen_k) + vgrad = self.unpad(vgrad, cu_seqlen_k) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # if varlen_padded: + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionVarlenQKVPackedGQA2( + TestFlashAttentionVarlenQKVPackedGQA +): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPacked(TestFlashAttentionVarlenQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQA(TestFlashAttentionGQA): + def calc_qkvpackedfa(self, q, k, v, out_grad, causal): + # q, k, v = self.clone_tensor([q, k, v]) + tq = q.reshape( + 
[ + self.batch_size, + self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ], + ) + kv = paddle.stack([k, v], axis=2) + qkv = paddle.concat([tq, kv], axis=2) + (qkv,) = self.clone_tensor([qkv]) + out = flash_attn_qkvpacked(qkv, causal=causal) + out = out[0] + out.backward(out_grad) + qkvgrad = qkv.grad + qgrad = qkvgrad[:, :, :-2].reshape(q.shape) + kgrad = qkvgrad[:, :, -2].reshape(k.shape) + vgrad = qkvgrad[:, :, -1].reshape(v.shape) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionQKVPackedGQA2(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPacked(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedGQADeter( + TestFlashAttentionVarlenQKVPackedGQA +): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, + t2, + err_msg=f"Tensor{i} causal={causal} varlen_padded={varlen_padded}", + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class TestFlashAttentionVarlenQKVPackedGQADeter2( +# TestFlashAttentionVarlenQKVPackedGQADeter +# ): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedDeter( + TestFlashAttentionVarlenQKVPackedGQADeter +): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQADeter(TestFlashAttentionQKVPackedGQA): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) 
= self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, t2, err_msg=f"Tensor{i} error, causal={causal}" + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class TestFlashAttentionQKVPackedDeter2(TestFlashAttentionQKVPackedGQADeter): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedDeter(TestFlashAttentionQKVPackedGQADeter): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + if __name__ == '__main__': unittest.main() diff --git a/third_party/flashattn b/third_party/flashattn index d98d8a36cc9b88..22b604199d911d 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit d98d8a36cc9b884a1f405d187a0c41caeb5144c6 +Subproject commit 22b604199d911d4e155fe9e54124148c7a290263
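
Usage note (not part of the patch): a minimal sketch of how the new packed-QKV API introduced above is driven, based on the packing layout used in the patch's docstrings and tests. It assumes a GPU build with FlashAttention enabled and illustrative shapes (batch=2, seqlen=128, 8 query heads, 2 key/value heads, head_dim=64); dtypes must be float16 or bfloat16.

    # Sketch: pack Q/K/V into the [batch, seqlen, nheads/nheads_k + 2, nheads_k, head_dim]
    # layout expected by flash_attn_qkvpacked (mirrors TestFlashAttentionQKVPackedGQA above).
    import paddle
    from paddle.nn.functional import flash_attn_qkvpacked

    batch, seqlen, num_heads, num_heads_k, head_dim = 2, 128, 8, 2, 64
    group = num_heads // num_heads_k  # query heads per key/value head (GQA)

    q = paddle.randn([batch, seqlen, num_heads, head_dim], dtype='float16')
    k = paddle.randn([batch, seqlen, num_heads_k, head_dim], dtype='float16')
    v = paddle.randn([batch, seqlen, num_heads_k, head_dim], dtype='float16')

    # Split the query heads into GQA groups, then append K and V as the last two groups.
    tq = q.reshape([batch, seqlen, group, num_heads_k, head_dim])
    kv = paddle.stack([k, v], axis=2)        # [batch, seqlen, 2, num_heads_k, head_dim]
    qkv = paddle.concat([tq, kv], axis=2)    # [batch, seqlen, group + 2, num_heads_k, head_dim]

    out, _ = flash_attn_qkvpacked(qkv, dropout=0.0, causal=True)
    print(out.shape)  # [batch, seqlen, num_heads, head_dim]

The varlen variant, flash_attn_varlen_qkvpacked, takes the same packing along the head-group axis but with the batch and sequence dimensions flattened to [total_seq_len, group + 2, nheads_k, head_dim], plus cu_seqlens_q/cu_seqlens_k offsets, as shown in the docstring example added in flash_attention.py.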