From 20a4d4c264f0e5a8fcda4bb636e380ab5d048afd Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Sun, 7 Apr 2024 19:33:03 +0800 Subject: [PATCH 01/16] Flashattention support qkvpacked and varlen --- paddle/phi/api/yaml/backward.yaml | 24 + paddle/phi/api/yaml/ops.yaml | 25 + paddle/phi/core/kernel_factory.cc | 6 +- paddle/phi/core/kernel_factory.h | 3 + paddle/phi/infermeta/backward.cc | 7 + paddle/phi/infermeta/backward.h | 2 + paddle/phi/infermeta/ternary.cc | 28 + paddle/phi/infermeta/ternary.h | 6 + .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 519 +++++++++++++++--- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 245 ++++++++- .../paddle/nn/functional/flash_attention.py | 278 +++++++++- test/legacy_test/test_flash_attention.py | 419 ++++++++++++++ third_party/flashattn | 2 +- 13 files changed, 1482 insertions(+), 82 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 603b65c8b4c53..626e0b1b79121 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -846,6 +846,18 @@ kernel : func : flash_attn_grad data_type: q + +- backward_op : flash_attn_qkvpacked_grad + forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked_grad + data_type: qkv - backward_op : flash_attn_unpadded_grad forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -869,6 +881,18 @@ kernel : func : flash_attn_with_sparse_mask_grad data_type: q + +- backward_op : flash_attn_varlen_qkvpacked_grad + forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) + optional : attn_mask + output : Tensor(qkv_grad) + infer_meta : + func : FlashAttnQKVPackedGradInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked_grad + data_type: qkv - backward_op : flatten_grad forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 918cbb980d00f..48b4312a18e1a 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1044,6 
+1044,18 @@ backward : flash_attn_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : flash_attn_qkvpacked + args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset, attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_qkvpacked + data_type : qkv + backward : flash_attn_qkvpacked_grad + - op : flash_attn_unpadded args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -1069,6 +1081,19 @@ data_type : q backward : flash_attn_with_sparse_mask_grad +- op : flash_attn_varlen_qkvpacked + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset , attn_mask + infer_meta : + func : FlashAttnQKVPackedInferMeta + param : [qkv] + kernel : + func : flash_attn_varlen_qkvpacked + data_type : qkv + intermediate : softmax_lse, seed_offset + backward : flash_attn_varlen_qkvpacked_grad + - op : flatten args : (Tensor x, int start_axis = 1, int stop_axis = 1) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 32644cfe8bf63..4e642b7475aa4 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -270,7 +270,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( auto kernel_iter = iter->second.find( {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()}); if (kernel_iter != iter->second.end()) { - return {kernel_iter->second, false, false}; + return {kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; } kernel_key = KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype()); @@ -351,7 +351,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( << ", expected_kernel_key:" << kernel_key << ", fallbacking to CPU one!"; - return {kernel_iter->second, true, false}; + return {kernel_iter->second, true, kernel_iter->second.IsSupportStride()}; } PADDLE_ENFORCE_NE( @@ -366,7 +366,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_name, KernelSelectionErrorMessage(kernel_name, kernel_key))); - return {kernel_iter->second, false, false}; + return {kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; } const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 4bdb8482c43b7..5960a4dc3bc59 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -281,6 +281,8 @@ class Kernel { return kernel_registered_type_; } + bool IsSupportStride() const { return support_stride_; } + void SetSupportStride(bool support) { support_stride_ = support; } GetKernelTypeForVarFn 
get_kerneltype_forvar_fn_{nullptr}; std::function check_if_onednn_kernel_support_{ nullptr}; @@ -290,6 +292,7 @@ class Kernel { void* variadic_fn_ = nullptr; KernelArgsDef args_def_; KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION; + bool support_stride_ = false; }; using KernelKeyMap = paddle::flat_hash_map; diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 261b99512a0ff..4071887f49d43 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -244,6 +244,13 @@ void FlashAttnGradInferMeta(const MetaTensor& q, } } +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, + MetaTensor* dqkv) { + if (dqkv) { + dqkv->share_meta(qkv); + } +} + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 88aea8f18181b..86d94bc6498f5 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -197,6 +197,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q, MetaTensor* dk, MetaTensor* dv); +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq); + void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset, const MetaTensor& out_grad, MetaTensor* x_grad, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index f10a86b33836a..92f18a1991e30 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -17,7 +17,10 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/common/ddim.h" +#include "paddle/common/errors.h" #include "paddle/common/layout.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/impl/box_coder.h" @@ -371,6 +374,31 @@ void FlashAttnInferMeta(const MetaTensor& q, seed_offset->set_dtype(phi::DataType::INT64); } } +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset) { + const auto& qkvdims = qkv.dims(); + PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5, + phi::errors::InvalidArgument( + "qkv dims must be 4(unpadded) or 5(padded batch)")); + // qkv [total_*,nheads/nheads_k+2,nheads_k,headdim] + auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]}); + if (qkvdims.size() == 5) { + // qkv [batchsize,seqlen,nheads/nheads_k+2,nheads_k,headdim] + out_dims = + DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]}; + } + out->set_dims(out_dims); + out->set_dtype(qkv.dtype()); + out->set_layout(qkv.layout()); + softmax->set_dtype(qkv.dtype()); + softmax_lse->set_dtype(qkv.dtype()); + if (seed_offset) { + seed_offset->set_dtype(phi::DataType::INT64); + } +} void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& end, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index c1c1af6f08218..0a3d7f06c9105 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -98,6 +98,12 @@ void FlashAttnInferMeta(const MetaTensor& q, MetaTensor* softmax, MetaTensor* softmax_lse, MetaTensor* seed_offset); + +void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, + MetaTensor* out, + MetaTensor* softmax, + MetaTensor* softmax_lse, + MetaTensor* seed_offset) ; void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, diff 
--git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4f93288edaf14..ccd87704b541a 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -13,12 +13,15 @@ // limitations under the License. #include "paddle/phi/kernels/flash_attn_grad_kernel.h" +#include #include "glog/logging.h" // For VLOG() #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" @@ -30,27 +33,109 @@ int get_num_split() { // 0 for an internal heuristic, which is optimal return FLAGS_cudnn_deterministic ? 1 : 0; } +bool isContiguous(const DenseTensor& t) { + auto rank = t.dims().size(); + auto s = t.strides()[rank - 1]; + if (s != 1) return false; + for (auto i = rank - 1; i > 0;) { + s *= t.dims()[i]; + i--; + if (t.strides()[i] != s) { + return false; + } + } + return true; +} +template +__global__ void SumStridedKV(T* src, + T* dst, + size_t sRowDim1, + size_t sRowDim2, + size_t sRowDim3, + size_t sColDim, + size_t sRowStride1, + size_t sRowStride2, + size_t sRowStride3, + size_t sColStride, + size_t dRowStride1, + size_t dRowStride2, + size_t dRowStride3) { + for (size_t row1 = blockIdx.x; row1 < sRowDim1; row1 += gridDim.x) + for (size_t row2 = 0; row2 < sRowDim2; row2++) + for (size_t row3 = threadIdx.x; row3 < sRowDim3; row3 += blockDim.x) { + T v{0}; + for (size_t col = 0; col < sColDim; col++) { + v += src[row1 * sRowStride1 + row2 * sRowStride2 + + row3 * sRowStride3 + col * sColStride]; + } + dst[row1 * dRowStride1 + row2 * dRowStride2 + row3 * dRowStride3] = v; + } +} template -void FlashAttnUnpaddedGradKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const DenseTensor& cu_seqlens_q, - const DenseTensor& cu_seqlens_k, - const DenseTensor& out, - const DenseTensor& softmax_lse, - const DenseTensor& seed_offset, - const paddle::optional& attn_mask, - const DenseTensor& dout, - int64_t max_seqlen_q, - int64_t max_seqlen_k, - float scale, - float dropout, - bool causal, - DenseTensor* dq, - DenseTensor* dk, - DenseTensor* dv) { +void kvReduceForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor& dk) { + const size_t reduceDimSize = dk_tmp.dims()[2]; + const size_t blockNum = std::min((dk_tmp.dims()[0] + 127) / 128, 1024l); + SumStridedKV<<>>((T*)dk_tmp.data(), + (T*)dk.data(), + dk_tmp.dims()[0], + dk_tmp.dims()[1], + dk_tmp.dims()[3], + dk_tmp.dims()[2], + dk_tmp.strides()[0], + dk_tmp.strides()[1], + dk_tmp.strides()[3], + dk_tmp.strides()[2], + dk.strides()[0], + dk.strides()[1], + dk.strides()[2]); +} +template +void kvReduceBatchedForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor& dk) { + const size_t reduceDimSize = dk_tmp.dims()[3]; + const size_t blockNum = std::min((dk_tmp.dims()[0] + 127) / 128, 1024l); + // here implicitly flat [batch,seqlen], and require batch dim to be contiguous + SumStridedKV + <<>>((T*)dk_tmp.data(), + (T*)dk.data(), + dk_tmp.dims()[0] * dk_tmp.dims()[1], + dk_tmp.dims()[2], + dk_tmp.dims()[4], + dk_tmp.dims()[3], + dk_tmp.strides()[1], + dk_tmp.strides()[2], + dk_tmp.strides()[4], + dk_tmp.strides()[3], + dk.strides()[1], + 
dk.strides()[2], + dk.strides()[3]); +} +template +void FlashAttnUnpaddedGradBaseKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv, + bool varlen_padded) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [total_*, num_heads, head_dim] auto dims = q.dims(); @@ -64,37 +149,30 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } std::initializer_list dk_dv_shape = { total_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - if (dk && is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -139,9 +217,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, cu_seqlens_q.data(), cu_seqlens_k.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -162,20 +240,209 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? 
params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out.strides()[0], + out.strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out.strides()[0], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + kdq->strides()[1], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[0], + dout.strides()[1], + max_seqlen_q * kdq->strides()[0], + max_seqlen_k * kdk->strides()[0], + max_seqlen_k * kdv->strides()[0], + max_seqlen_q * dout.strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + if (isContiguous(*dk)) + phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); + else + kvReduceForGQA(ctx, dk_tmp, *dk); } if (dv) { - phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + if (isContiguous(*dv)) + phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); + else + kvReduceForGQA(ctx, dv_tmp, *dv); } } #else RaiseNotSupportedError(); #endif } + +template +void FlashAttnUnpaddedGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + dq, + dk, + dv, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor& out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out.set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out.dtype())); +} +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; +template +void FlashAttnVarlenQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool varlen_padded, + DenseTensor* dqkv) { +#ifdef 
PADDLE_WITH_FLASHATTN + // q,k,v [total_*, num_heads, head_dim] + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, v, 1, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + { + std::vector inputs{}; + std::vector outputs{dqkv}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, dq, 1, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, dk, 1, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, dv, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedGradBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + out, + softmax_lse, + seed_offset, + attn_mask, + dout, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + &dq, + &dk, + &dv, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} template void FlashAttnGradBaseKernel( const Context& ctx, @@ -208,36 +475,29 @@ void FlashAttnGradBaseKernel( bool is_mha = (num_heads == num_heads_k); - void* dq_ptr = nullptr; - void* dk_ptr = nullptr; - void* dv_ptr = nullptr; - + std::initializer_list dk_dv_shape = { + batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; + DenseTensor* kdq = dq; DenseTensor dq_tmp; - if (dq) { - dq_ptr = ctx.template Alloc(dq); - } else { + if (!dq) { dq_tmp.Resize(dims); - dq_ptr = ctx.template Alloc(&dq_tmp); + ctx.template Alloc(&dq_tmp); + kdq = &dq_tmp; } + DenseTensor *kdk = dk, *kdv = dv; DenseTensor dk_tmp; - std::initializer_list dk_dv_shape = { - batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}; - if (dk && is_mha) { - ctx.template Alloc(dk); - dk_ptr = dk->data(); - } else { + if (!dk || !is_mha) { dk_tmp.Resize(dk_dv_shape); - dk_ptr = ctx.template Alloc(&dk_tmp); + ctx.template Alloc(&dk_tmp); + kdk = &dk_tmp; } DenseTensor dv_tmp; - if (dv && is_mha) { - ctx.template Alloc(dv); - dv_ptr = dv->data(); - } else { + if (!dv || !is_mha) { dv_tmp.Resize(dk_dv_shape); - dv_ptr = ctx.template Alloc(&dv_tmp); + ctx.template Alloc(&dv_tmp); + kdv = &dv_tmp; } const cudaStream_t stream = ctx.stream(); @@ -291,9 +551,9 @@ void FlashAttnGradBaseKernel( params.softmax_d.data(), softmax_lse.data(), params.rng_state.data(), - dq_ptr, - dk_ptr, - dv_ptr, + kdq->data(), + kdk->data(), + kdv->data(), params.dq_accum.data(), params.batch_size, params.max_seqlen_q, @@ -321,14 +581,45 @@ void FlashAttnGradBaseKernel( params.attn_mask_start_row_indices_tensor ? 
params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out.strides()[1], + out.strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out.strides()[0], + kdq->strides()[1], + kdk->strides()[1], + kdv->strides()[1], + kdq->strides()[2], + kdk->strides()[kdk->strides().size() - 2], + kdv->strides()[kdv->strides().size() - 2], + dout.strides()[1], + dout.strides()[2], + kdq->strides()[0], + kdk->strides()[0], + kdv->strides()[0], + dout.strides()[0]); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + if (isContiguous(*dk)) + phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); + else + kvReduceBatchedForGQA(ctx, dk_tmp, *dk); } + if (dv) { - phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + if (isContiguous(*dv)) + phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); + else + kvReduceBatchedForGQA(ctx, dv_tmp, *dv); } } #else @@ -351,6 +642,15 @@ void FlashAttnGradKernel(const Context& ctx, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -369,6 +669,58 @@ void FlashAttnGradKernel(const Context& ctx, dv); } +template +void FlashAttnQKVPackedGradKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dqkv) { +#ifdef PADDLE_WITH_FLASHATTN + // qkv [batchsize, seqlen, nheads/nheads_k+2, nheads_k, head_dim] + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, v, 2, head_groupnum - 1, 1); + // DenseTensor dqkv_tmp; + if (!dqkv) { + return; + // dqkv is the only output. 
No need to compute if no dqkv + // dqkv_tmp.Resize(qkv.dims()); + // dqkv = &dqkv_tmp; + } + ctx.template Alloc(dqkv); + DenseTensor dq, dk, dv; + sliceFlattenView(*dqkv, dq, 2, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, dk, 2, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, dv, 2, head_groupnum - 1, 1); + FlashAttnGradBaseKernel(ctx, + q, + k, + v, + out, + softmax_lse, + seed_offset, + attn_mask, + paddle::none, + dout, + dropout, + causal, + 0, + &dq, + &dk, + &dv); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseGradKernel( const Context& ctx, @@ -386,6 +738,15 @@ void FlashAttnWithSparseGradKernel( DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { + if (dq) { + ctx.template Alloc(dq); + } + if (dk) { + ctx.template Alloc(dk); + } + if (dv) { + ctx.template Alloc(dv); + } FlashAttnGradBaseKernel(ctx, q, k, @@ -414,6 +775,15 @@ PD_REGISTER_KERNEL(flash_attn_unpadded_grad, kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_grad, GPU, ALL_LAYOUT, @@ -423,6 +793,15 @@ PD_REGISTER_KERNEL(flash_attn_grad, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, GPU, ALL_LAYOUT, @@ -430,4 +809,4 @@ PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, phi::dtype::float16, phi::dtype::bfloat16) { kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset -} +} \ No newline at end of file diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 7eb2d342feb79..cc1ff47b452e9 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -14,17 +14,26 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" +#include #include "glog/logging.h" // For VLOG() #include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" namespace phi { +template +struct ZeroFunctor { + __device__ __forceinline__ OutT operator()() const { + return static_cast(0); + } +}; template -void FlashAttnUnpaddedKernel( +void FlashAttnUnpaddedBaseKernel( const Context& ctx, const DenseTensor& q, const DenseTensor& k, @@ -44,10 +53,16 @@ void FlashAttnUnpaddedKernel( DenseTensor* out, DenseTensor* softmax, DenseTensor* softmax_lse, - DenseTensor* seed_offset) { + DenseTensor* seed_offset, + bool varlen_padded) { #ifdef PADDLE_WITH_FLASHATTN ctx.template Alloc(out); + if (varlen_padded) { + std::vector inputs{}; + std::vector outputs{out}; + phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); + } cudaStream_t stream = ctx.stream(); // q, k, v [total_q/k/v, num_heads, head_dim] @@ -120,13 +135,158 @@ void FlashAttnUnpaddedKernel( params.seed, params.offset, params.attn_mask_tensor ? 
params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? params.mask_dims.data() : nullptr, + q.strides()[0], + k.strides()[0], + v.strides()[0], + q.strides()[1], + k.strides()[1], + v.strides()[1], + out->strides()[0], + out->strides()[1], + max_seqlen_q * q.strides()[0], + max_seqlen_k * k.strides()[0], + max_seqlen_k * v.strides()[0], + max_seqlen_q * out->strides()[0], + varlen_padded); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); #endif } +template +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + false /*varlen_padded*/); +#else + RaiseNotSupportedError(); +#endif +} + +static void sliceFlattenView(const DenseTensor& in, + DenseTensor& out, + int axis, + int64_t offset, + int64_t sliceLength) { + PADDLE_ENFORCE_LT( + axis, + in.dims().size(), + phi::errors::InvalidArgument("sliceView receive axis out of bound")); + std::array dimArr; + std::array strideArr; + auto id = dimArr.begin(), is = strideArr.begin(); + for (int i = 0; i < in.dims().size(); i++) { + if (i == axis) continue; + if (i == axis + 1) + *id = in.dims()[i] * sliceLength; + else + *id = in.dims()[i]; + *is = in.strides()[i]; + id++; + is++; + } + out = DenseTensor{ + in.Holder(), + DenseTensorMeta{in.dtype(), + DDim{dimArr.data(), in.dims().size() - 1}, + DDim(strideArr.data(), in.dims().size() - 1)}}; + out.set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out.dtype())); +} +template +void FlashAttnVarlenQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + bool varlen_padded, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, v, 1, head_groupnum - 1, 1); + FlashAttnUnpaddedBaseKernel(ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + out, + softmax, + softmax_lse, + seed_offset, + varlen_padded); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnBaseKernel( const Context& ctx, @@ -239,7 +399,19 @@ void FlashAttnBaseKernel( 
params.attn_mask_start_row_indices_tensor ? params.attn_mask_start_row_indices_dims.data() : nullptr, - params.attn_mask_start_row); + params.attn_mask_start_row, + q.strides()[1], + k.strides()[1], + v.strides()[1], + q.strides()[2], + k.strides()[2], + v.strides()[2], + out->strides()[1], + out->strides()[2], + q.strides()[0], + k.strides()[0], + v.strides()[0], + out->strides()[0]); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); @@ -281,6 +453,49 @@ void FlashAttnKernel(const Context& ctx, seed_offset); } +template +void FlashAttnQKVPackedKernel( + const Context& ctx, + const DenseTensor& qkv, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN + const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 + DenseTensor q, k, v; + sliceFlattenView(qkv, q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, v, 2, head_groupnum - 1, 1); + FlashAttnBaseKernel(ctx, + q, + k, + v, + fixed_seed_offset, + attn_mask, + paddle::none, + dropout, + causal, + return_softmax, + is_test, + rng_name, + 0, + out, + softmax, + softmax_lse, + seed_offset); +#else + RaiseNotSupportedError(); +#endif +} + template void FlashAttnWithSparseMaskKernel( const Context& ctx, @@ -330,6 +545,16 @@ PD_REGISTER_KERNEL(flash_attn_unpadded, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_varlen_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnVarlenQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(3).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn, GPU, ALL_LAYOUT, @@ -340,6 +565,16 @@ PD_REGISTER_KERNEL(flash_attn, phi::Backend::ALL_BACKEND); // fixed_seed_offset } +PD_REGISTER_KERNEL(flash_attn_qkvpacked, + GPU, + ALL_LAYOUT, + phi::FlashAttnQKVPackedKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(1).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, GPU, ALL_LAYOUT, @@ -348,4 +583,4 @@ PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, phi::dtype::bfloat16) { kernel->InputAt(4).SetBackend( phi::Backend::ALL_BACKEND); // fixed_seed_offset -} +} \ No newline at end of file diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index e82684c32981d..0cce07d837daf 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -72,9 +72,7 @@ def _math_attention( query = paddle.transpose(query, [0, 2, 1, 3]) key = paddle.transpose(key, [0, 2, 1, 3]) value = paddle.transpose(value, [0, 2, 1, 3]) - product = paddle.matmul( - x=query * (head_dim**-0.5), y=key, transpose_y=True - ) + product = paddle.matmul(x=query * (head_dim**-0.5), y=key, transpose_y=True) if not causal: weights = F.softmax(product) @@ -300,6 +298,152 @@ def flash_attention( ) +def flash_attn_qkvpacked( + qkv, + dropout=0.0, + causal=False, + return_softmax=False, + *, + fixed_seed_offset=None, + rng_name="", + training=True, + name=None, +): + r""" + The equation is: + + .. 
math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where: ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Args: + qkv(Tensor): The query/key/value packed tensor in the Attention module. + 5-D tensor with shape: + [batchsize, seqlen, num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. + dropout(float): The dropout ratio. + causal(bool): Whether to enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + training(bool): Whether it is in the training phase. + rng_name(str): The name to select Generator. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + out(Tensor): The attention tensor. + 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. + softmax(Tensor): The softmax tensor. None if return_softmax is False. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.seed(2023) + >>> q = paddle.rand((1, 128, 2, 16)) + >>> qkv = paddle.stack([q,q,q], axis=2) + >>> output = paddle.nn.functional.flash_attention.flash_attn_qkvpacked(qkv, 0.9, False, False) + >>> print(output) + (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, + 0.42132431, 0.39157745], + [0.76687670, 0.65837246, 0.69117945, ..., 0.82817286, + 0.76690865, 0.71485823]], + ..., + [[0.71662450, 0.57275224, 0.57053083, ..., 0.48108247, + 0.53336465, 0.54540104], + [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, + 0.40526697, 0.60541755]]]]), None) + + """ + head_dim = qkv.shape[-1] + sdp_func_name = _select_sdp(head_dim) + + if sdp_func_name == "flash_attn": + if in_dynamic_or_pir_mode(): + (result_attention, result_softmax, _, _) = _C_ops.flash_attn_qkvpacked( + qkv, + fixed_seed_offset, + None, + dropout, + causal, + return_softmax, + not training, + rng_name, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + else: + # don't call qkvpacked if not using flash_attn + query = qkv[:, :, :-2].reshape([0, 0, -1, qkv.shape[-1]]) + key = qkv[:, :, -2] + value = qkv[:, :, -1] + if sdp_func_name == "mem_efficient": + from paddle.incubate.nn.memory_efficient_attention import (
memory_efficient_attention, + ) + + output = memory_efficient_attention( + query, + key, + value, + attn_bias=None, + p=dropout, + scale=None, + training=training, + ) + return output, None + else: + return _math_attention( + query, + key, + value, + dropout_rate=dropout, + causal=causal, + return_softmax=return_softmax, + training=training, + ) + + def flash_attn_unpadded( query, key, @@ -439,6 +583,134 @@ def flash_attn_unpadded( return out, softmax if return_softmax else None +def flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + scale, + dropout=0.0, + causal=False, + return_softmax=False, + fixed_seed_offset=None, + rng_name="", + varlen_padded=True, + training=True, + name=None, +): + r""" + The equation is: + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where: ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Warning: + This API only supports inputs with dtype float16 and bfloat16. + + Args: + qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed. + 4-D tensor with shape: + [total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. + The dtype can be float16 or bfloat16. + cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index query. + cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch, + used to index key and value. + max_seqlen_q(int): Maximum sequence length of query in the batch. Note that this is the padded length, not the maximum actual sequence length. + max_seqlen_k(int): Maximum sequence length of key/value in the batch. + scale(float): The scaling of QK^T before applying softmax. + dropout(float): The dropout ratio. + causal(bool): Whether to enable causal mode. + return_softmax(bool): Whether to return softmax. + fixed_seed_offset(Tensor, optional): With fixed seed, offset for dropout mask. + rng_name(str): The name to select Generator. + varlen_padded(bool): Whether the input qkv is padded per sequence to max_seqlen. When True, the padded positions are not computed and the output padding is filled with zeros. Default is True. + training(bool): Whether it is in the training phase. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + out(Tensor): The attention tensor. The tensor is padded by zeros. + 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. + softmax(Tensor): The softmax tensor. None if return_softmax is False. + + Examples: + ..
code-block:: python + + >>> import paddle + >>> paddle.seed(2023) + >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') + >>> cu = paddle.arange(0, 384, 128, dtype='int32') + >>> qq = paddle.reshape(q, [256, 8, 16]) + >>> qkv = paddle.stack([qq,qq,qq], axis=2) + >>> output = paddle.nn.functional.flash_attention.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) + + """ + if in_dynamic_mode(): + ( + result_attention, + result_softmax, + ) = _C_ops.flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + None, + max_seqlen_q, + max_seqlen_k, + scale, + dropout, + causal, + return_softmax, + not training, + rng_name, + varlen_padded, + ) + return result_attention, result_softmax if return_softmax else None + + helper = LayerHelper('flash_attn_varlen_qkvpacked', **locals()) + dtype = helper.input_dtype(input_param_name='qkv') + out = helper.create_variable_for_type_inference(dtype) + softmax = helper.create_variable_for_type_inference(dtype) + softmax_lse = helper.create_variable_for_type_inference(paddle.float32) + seed_offset = helper.create_variable_for_type_inference(paddle.int64) + inputs = { + 'qkv': qkv, + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_k': cu_seqlens_k, + 'fixed_seed_offset': fixed_seed_offset, + } + outputs = { + 'out': out, + 'softmax': softmax, + 'softmax_lse': softmax_lse, + 'seed_offset': seed_offset, + } + helper.append_op( + type='flash_attn_varlen_qkvpacked', + inputs=inputs, + outputs=outputs, + attrs={ + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_k': max_seqlen_k, + 'scale': scale, + 'dropout': dropout, + 'causal': causal, + 'return_softmax': return_softmax, + 'is_test': not training, + 'rng_name': rng_name, + }, + ) + return out, softmax if return_softmax else None + + def scaled_dot_product_attention( query, key, diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index 343cb02e216d2..c8f9509bd75e8 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -27,6 +27,8 @@ flash_attention, flash_attention_with_sparse_mask, flash_attn_unpadded, + flash_attn_varlen_qkvpacked, + flash_attn_qkvpacked, scaled_dot_product_attention, ) from paddle.pir_utils import test_with_pir_api @@ -957,5 +959,422 @@ def setUp(self): self.causal = True +class TestFlashAttentionVarlenQKVPackedGQA(TestFlashAttentionGQA): + def gen_unpadded_data(self, dtype): + seq_len_q = np.random.randint( + low=1, high=self.seq_len, size=[self.batch_size] + ) + seq_len_k = seq_len_q + cu_seqlen_q = paddle.to_tensor( + [0] + np.cumsum(seq_len_q).tolist(), dtype=paddle.int32 + ) + cu_seqlen_k = cu_seqlen_q + + qs, ks, vs = [], [], [] + for i in range(self.batch_size): + tmp_q = ( + paddle.randn( + [seq_len_q[i] * self.num_head * self.head_dim], dtype=dtype + ) + / 1e2 + ) + tmp_k = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + tmp_v = ( + paddle.randn( + [ + seq_len_k[i] + * self.num_head + * self.head_dim + // self.num_group + ], + dtype=dtype, + ) + / 1e2 + ) + qs.append(tmp_q) + ks.append(tmp_k) + vs.append(tmp_v) + + q = paddle.concat(qs, axis=0).reshape( + [-1, self.num_head, self.head_dim] + ) + k = paddle.concat(ks, axis=0).reshape( + [-1, self.num_head // self.num_group, self.head_dim] + ) + v = paddle.concat(vs, axis=0).reshape( + [-1, self.num_head // self.num_group, self.head_dim] + ) + return q, k, v, cu_seqlen_q, cu_seqlen_k + + def calc_qkvpackedfa( + 
self, q, k, v, cu_seqlen_q, cu_seqlen_k, out_grad, causal, varlen_padded + ): + q, k, v = self.clone_tensor([q, k, v]) + scale = self.head_dim ** (-0.5) + if varlen_padded: + tq = q.reshape( + [ + self.batch_size * self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tk = k.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + tv = v.reshape( + [ + self.batch_size * self.seq_len, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([tk, tv], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out_grad = out_grad.reshape(out[0].shape) + else: + tq = q.reshape( + [ + 0, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ] + ) + kv = paddle.stack([k, v], axis=1) + qkv = paddle.concat([tq, kv], axis=1) + out = flash_attn_varlen_qkvpacked( + qkv, + cu_seqlens_q=cu_seqlen_q, + cu_seqlens_k=cu_seqlen_k, + max_seqlen_q=self.seq_len, + max_seqlen_k=self.seq_len, + scale=scale, + causal=causal, + varlen_padded=varlen_padded, + ) + out = out[0] + grads = paddle.grad(outputs=out, inputs=qkv, grad_outputs=out_grad) + qkvgrad = grads[0] + out = out.reshape(q.shape) + qgrad = qkvgrad[:, :-2].reshape(q.shape) + kgrad = qkvgrad[:, -2].reshape(k.shape) + vgrad = qkvgrad[:, -1].reshape(v.shape) + if varlen_padded: + out = self.unpad(out, cu_seqlen_q) + qgrad = self.unpad(qgrad, cu_seqlen_q) + kgrad = self.unpad(kgrad, cu_seqlen_k) + vgrad = self.unpad(vgrad, cu_seqlen_k) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # if varlen_padded: + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionVarlenQKVPackedGQA2( + TestFlashAttentionVarlenQKVPackedGQA +): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPacked(TestFlashAttentionVarlenQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQA(TestFlashAttentionGQA): + def calc_qkvpackedfa(self, q, k, v, out_grad, causal): + # q, k, v = self.clone_tensor([q, k, v]) + tq = q.reshape( + [ + self.batch_size, + self.seq_len, + self.num_group, + self.num_head // self.num_group, + self.head_dim, + ], + ) + kv = 
paddle.stack([k, v], axis=2) + qkv = paddle.concat([tq, kv], axis=2) + (qkv,) = self.clone_tensor([qkv]) + out = flash_attn_qkvpacked(qkv, causal=causal) + out = out[0] + out.backward(out_grad) + qkvgrad = qkv.grad + qgrad = qkvgrad[:, :, :-2].reshape(q.shape) + kgrad = qkvgrad[:, :, -2].reshape(k.shape) + vgrad = qkvgrad[:, :, -1].reshape(v.shape) + return self.convert_dtype([out, qgrad, kgrad, vgrad]) + + def test_main(self): + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = self.calc_raw_attn( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_allclose(t1, t2, atol=1e-2, rtol=1e-2) + + +class TestFlashAttentionQKVPackedGQA2(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 2 + self.num_head = 16 + self.seq_len = 2048 + self.head_dim = 128 + self.num_group = 4 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPacked(TestFlashAttentionQKVPackedGQA): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedGQADeter( + TestFlashAttentionVarlenQKVPackedGQA +): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + for varlen_padded in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, True) + if varlen_padded: + q_pad, _ = self.pad(q, cu_seqlen_q, self.seq_len) + k_pad, _ = self.pad(k, cu_seqlen_k, self.seq_len) + v_pad, _ = self.pad(v, cu_seqlen_k, self.seq_len) + out_grad_pad, _ = self.pad( + out_grad, cu_seqlen_q, self.seq_len + ) + else: + q_pad = q + k_pad = k + v_pad = v + out_grad_pad = out_grad + fa_out = self.calc_qkvpackedfa( + q_pad, + k_pad, + v_pad, + cu_seqlen_q, + cu_seqlen_k, + out_grad_pad, + causal, + varlen_padded, + ) + # cu_seqlen_q = None + # cu_seqlen_k = None + raw_out = self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + True, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, + t2, + err_msg=f"Tensor{i} causal={causal} varlen_padded={varlen_padded}", + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class TestFlashAttentionVarlenQKVPackedGQADeter2( +# TestFlashAttentionVarlenQKVPackedGQADeter +# ): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionVarlenQKVPackedDeter( + TestFlashAttentionVarlenQKVPackedGQADeter +): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedGQADeter(TestFlashAttentionQKVPackedGQA): + def test_main(self): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + for causal in [False, True]: + ( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + ) = self.gen_test_data(self.dtype, False) + fa_out = self.calc_qkvpackedfa(q, k, v, out_grad, causal) + raw_out = 
self.calc_fa( + q, + k, + v, + cu_seqlen_q, + cu_seqlen_k, + out_grad, + causal, + False, + ) + assert len(fa_out) == len(raw_out) + i = 0 + for t1, t2 in zip(fa_out, raw_out): + np.testing.assert_array_equal( + t1, t2, err_msg=f"Tensor{i} error, causal={causal}" + ) + i += 1 + paddle.set_flags({'FLAGS_cudnn_deterministic': 0}) + + +# can't bit-match dk,dv now when num_group more than 2, since the sum kernel is different and sum sequence not defined +# class TestFlashAttentionQKVPackedDeter2(TestFlashAttentionQKVPackedGQADeter): +# def setUp(self): +# self.batch_size = 2 +# self.num_head = 16 +# self.seq_len = 2048 +# self.head_dim = 128 +# self.num_group = 4 +# self.dtype = 'bfloat16' + + +class TestFlashAttentionQKVPackedDeter(TestFlashAttentionQKVPackedGQADeter): + def setUp(self): + self.batch_size = 3 + self.num_head = 7 + self.seq_len = 563 + self.head_dim = 64 + self.num_group = 1 + self.dtype = 'bfloat16' + + if __name__ == '__main__': unittest.main() diff --git a/third_party/flashattn b/third_party/flashattn index d98d8a36cc9b8..2521ed0f49750 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit d98d8a36cc9b884a1f405d187a0c41caeb5144c6 +Subproject commit 2521ed0f49750e45cb0a9769aaf445e28d776809 From bec1299ca2e28fdecd1dd057ea695f82362880b3 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Mon, 8 Apr 2024 10:00:59 +0800 Subject: [PATCH 02/16] fix codestyle --- paddle/phi/api/yaml/backward.yaml | 26 ++--- paddle/phi/api/yaml/ops.yaml | 24 ++-- paddle/phi/core/kernel_factory.cc | 3 +- paddle/phi/infermeta/backward.cc | 3 +- paddle/phi/infermeta/ternary.h | 4 +- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 107 +++++++++--------- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 22 ++-- .../paddle/nn/functional/flash_attention.py | 11 +- test/legacy_test/test_flash_attention.py | 2 +- 9 files changed, 106 insertions(+), 96 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 626e0b1b79121..b394f61d93d4d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -846,7 +846,7 @@ kernel : func : flash_attn_grad data_type: q - + - backward_op : flash_attn_qkvpacked_grad forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) @@ -871,17 +871,6 @@ func : flash_attn_unpadded_grad data_type: q -- backward_op : flash_attn_with_sparse_mask_grad - forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0) - output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) - infer_meta : - func : FlashAttnGradInferMeta - param : [q, k, v] - kernel : - func : flash_attn_with_sparse_mask_grad - data_type: q - - backward_op : 
flash_attn_varlen_qkvpacked_grad forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) @@ -892,7 +881,18 @@ param : [qkv] kernel : func : flash_attn_varlen_qkvpacked_grad - data_type: qkv + data_type: qkv + +- backward_op : flash_attn_with_sparse_mask_grad + forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0) + output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) + infer_meta : + func : FlashAttnGradInferMeta + param : [q, k, v] + kernel : + func : flash_attn_with_sparse_mask_grad + data_type: q - backward_op : flatten_grad forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 48b4312a18e1a..1e73ff769bddd 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1069,18 +1069,6 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_unpadded_grad -- op : flash_attn_with_sparse_mask - args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") - output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - optional : fixed_seed_offset - infer_meta : - func : FlashAttnInferMeta - param : [q, k, v] - kernel : - func : flash_attn_with_sparse_mask - data_type : q - backward : flash_attn_with_sparse_mask_grad - - op : flash_attn_varlen_qkvpacked args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) @@ -1094,6 +1082,18 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_varlen_qkvpacked_grad +- op : flash_attn_with_sparse_mask + args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset + infer_meta : + 
func : FlashAttnInferMeta + param : [q, k, v] + kernel : + func : flash_attn_with_sparse_mask + data_type : q + backward : flash_attn_with_sparse_mask_grad + - op : flatten args : (Tensor x, int start_axis = 1, int stop_axis = 1) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 4e642b7475aa4..0bf8403f15016 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -270,7 +270,8 @@ KernelResult KernelFactory::SelectKernelOrThrowError( auto kernel_iter = iter->second.find( {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()}); if (kernel_iter != iter->second.end()) { - return {kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; + return { + kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; } kernel_key = KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype()); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 4071887f49d43..8f4e48601e783 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -244,8 +244,7 @@ void FlashAttnGradInferMeta(const MetaTensor& q, } } -void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, - MetaTensor* dqkv) { +void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) { if (dqkv) { dqkv->share_meta(qkv); } diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 0a3d7f06c9105..f2490eac2f347 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -98,12 +98,12 @@ void FlashAttnInferMeta(const MetaTensor& q, MetaTensor* softmax, MetaTensor* softmax_lse, MetaTensor* seed_offset); - + void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv, MetaTensor* out, MetaTensor* softmax, MetaTensor* softmax_lse, - MetaTensor* seed_offset) ; + MetaTensor* seed_offset); void InstanceNormInferMeta(const MetaTensor& x, const MetaTensor& scale, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index ccd87704b541a..2c2df92678ca5 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -47,7 +47,7 @@ bool isContiguous(const DenseTensor& t) { return true; } template -__global__ void SumStridedKV(T* src, +__global__ void SumStridedKV(const T* src, T* dst, size_t sRowDim1, size_t sRowDim2, @@ -75,44 +75,47 @@ __global__ void SumStridedKV(T* src, template void kvReduceForGQA(const Context& ctx, const DenseTensor& dk_tmp, - DenseTensor& dk) { + DenseTensor* dk) { const size_t reduceDimSize = dk_tmp.dims()[2]; - const size_t blockNum = std::min((dk_tmp.dims()[0] + 127) / 128, 1024l); - SumStridedKV<<>>((T*)dk_tmp.data(), - (T*)dk.data(), - dk_tmp.dims()[0], - dk_tmp.dims()[1], - dk_tmp.dims()[3], - dk_tmp.dims()[2], - dk_tmp.strides()[0], - dk_tmp.strides()[1], - dk_tmp.strides()[3], - dk_tmp.strides()[2], - dk.strides()[0], - dk.strides()[1], - dk.strides()[2]); + const size_t blockNum = + std::min((static_cast(dk_tmp.dims()[0] + 127) / 128), + static_cast(1024l)); + SumStridedKV<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0], + dk_tmp.dims()[1], + dk_tmp.dims()[3], + dk_tmp.dims()[2], + dk_tmp.strides()[0], + dk_tmp.strides()[1], + dk_tmp.strides()[3], + dk_tmp.strides()[2], + dk->strides()[0], + dk->strides()[1], + dk->strides()[2]); } template void kvReduceBatchedForGQA(const Context& ctx, const DenseTensor& 
dk_tmp, - DenseTensor& dk) { + DenseTensor* dk) { const size_t reduceDimSize = dk_tmp.dims()[3]; const size_t blockNum = std::min((dk_tmp.dims()[0] + 127) / 128, 1024l); // here implicitly flat [batch,seqlen], and require batch dim to be contiguous - SumStridedKV - <<>>((T*)dk_tmp.data(), - (T*)dk.data(), - dk_tmp.dims()[0] * dk_tmp.dims()[1], - dk_tmp.dims()[2], - dk_tmp.dims()[4], - dk_tmp.dims()[3], - dk_tmp.strides()[1], - dk_tmp.strides()[2], - dk_tmp.strides()[4], - dk_tmp.strides()[3], - dk.strides()[1], - dk.strides()[2], - dk.strides()[3]); + SumStridedKV<<>>( + reinterpret_cast(dk_tmp.data()), + reinterpret_cast(dk->data()), + dk_tmp.dims()[0] * dk_tmp.dims()[1], + dk_tmp.dims()[2], + dk_tmp.dims()[4], + dk_tmp.dims()[3], + dk_tmp.strides()[1], + dk_tmp.strides()[2], + dk_tmp.strides()[4], + dk_tmp.strides()[3], + dk->strides()[1], + dk->strides()[2], + dk->strides()[3]); } template void FlashAttnUnpaddedGradBaseKernel( @@ -272,13 +275,13 @@ void FlashAttnUnpaddedGradBaseKernel( if (isContiguous(*dk)) phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); else - kvReduceForGQA(ctx, dk_tmp, *dk); + kvReduceForGQA(ctx, dk_tmp, dk); } if (dv) { if (isContiguous(*dv)) phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); else - kvReduceForGQA(ctx, dv_tmp, *dv); + kvReduceForGQA(ctx, dv_tmp, dv); } } #else @@ -342,7 +345,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, } static void sliceFlattenView(const DenseTensor& in, - DenseTensor& out, + DenseTensor* out, int axis, int64_t offset, int64_t sliceLength) { @@ -363,13 +366,13 @@ static void sliceFlattenView(const DenseTensor& in, id++; is++; } - out = DenseTensor{ + *out = DenseTensor{ in.Holder(), DenseTensorMeta{in.dtype(), DDim{dimArr.data(), in.dims().size() - 1}, DDim(strideArr.data(), in.dims().size() - 1)}}; - out.set_offset(in.offset() + - offset * in.strides()[axis] * SizeOf(out.dtype())); + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out.dtype())); } template struct ZeroFunctor { @@ -399,9 +402,9 @@ void FlashAttnVarlenQKVPackedGradKernel( // q,k,v [total_*, num_heads, head_dim] const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 DenseTensor q, k, v; - sliceFlattenView(qkv, q, 1, 0, head_groupnum - 2); - sliceFlattenView(qkv, k, 1, head_groupnum - 2, 1); - sliceFlattenView(qkv, v, 1, head_groupnum - 1, 1); + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); // DenseTensor dqkv_tmp; if (!dqkv) { return; @@ -416,9 +419,9 @@ void FlashAttnVarlenQKVPackedGradKernel( phi::funcs::ElementwiseKernel(ctx, inputs, &outputs, ZeroFunctor()); } DenseTensor dq, dk, dv; - sliceFlattenView(*dqkv, dq, 1, 0, head_groupnum - 2); - sliceFlattenView(*dqkv, dk, 1, head_groupnum - 2, 1); - sliceFlattenView(*dqkv, dv, 1, head_groupnum - 1, 1); + sliceFlattenView(*dqkv, &dq, 1, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 1, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 1, head_groupnum - 1, 1); FlashAttnUnpaddedGradBaseKernel(ctx, q, k, @@ -612,14 +615,14 @@ void FlashAttnGradBaseKernel( if (isContiguous(*dk)) phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); else - kvReduceBatchedForGQA(ctx, dk_tmp, *dk); + kvReduceBatchedForGQA(ctx, dk_tmp, dk); } if (dv) { if (isContiguous(*dv)) phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); else - kvReduceBatchedForGQA(ctx, dv_tmp, *dv); + kvReduceBatchedForGQA(ctx, dv_tmp, dv); } } #else @@ -685,9 +688,9 @@ void 
FlashAttnQKVPackedGradKernel( // qkv [batchsize, seqlen, nheads/nheads_k+2, nheads_k, head_dim] const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 DenseTensor q, k, v; - sliceFlattenView(qkv, q, 2, 0, head_groupnum - 2); - sliceFlattenView(qkv, k, 2, head_groupnum - 2, 1); - sliceFlattenView(qkv, v, 2, head_groupnum - 1, 1); + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); // DenseTensor dqkv_tmp; if (!dqkv) { return; @@ -697,9 +700,9 @@ void FlashAttnQKVPackedGradKernel( } ctx.template Alloc(dqkv); DenseTensor dq, dk, dv; - sliceFlattenView(*dqkv, dq, 2, 0, head_groupnum - 2); - sliceFlattenView(*dqkv, dk, 2, head_groupnum - 2, 1); - sliceFlattenView(*dqkv, dv, 2, head_groupnum - 1, 1); + sliceFlattenView(*dqkv, &dq, 2, 0, head_groupnum - 2); + sliceFlattenView(*dqkv, &dk, 2, head_groupnum - 2, 1); + sliceFlattenView(*dqkv, &dv, 2, head_groupnum - 1, 1); FlashAttnGradBaseKernel(ctx, q, k, @@ -809,4 +812,4 @@ PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, phi::dtype::float16, phi::dtype::bfloat16) { kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset -} \ No newline at end of file +} diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index cc1ff47b452e9..a7f7693e351cc 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -205,7 +205,7 @@ void FlashAttnUnpaddedKernel( } static void sliceFlattenView(const DenseTensor& in, - DenseTensor& out, + DenseTensor* out, int axis, int64_t offset, int64_t sliceLength) { @@ -226,13 +226,13 @@ static void sliceFlattenView(const DenseTensor& in, id++; is++; } - out = DenseTensor{ + *out = DenseTensor{ in.Holder(), DenseTensorMeta{in.dtype(), DDim{dimArr.data(), in.dims().size() - 1}, DDim(strideArr.data(), in.dims().size() - 1)}}; - out.set_offset(in.offset() + - offset * in.strides()[axis] * SizeOf(out.dtype())); + out->set_offset(in.offset() + + offset * in.strides()[axis] * SizeOf(out.dtype())); } template void FlashAttnVarlenQKVPackedKernel( @@ -258,9 +258,9 @@ void FlashAttnVarlenQKVPackedKernel( #ifdef PADDLE_WITH_FLASHATTN const auto head_groupnum = qkv.dims()[1]; // nheads/nheads_k + 1 + 1 DenseTensor q, k, v; - sliceFlattenView(qkv, q, 1, 0, head_groupnum - 2); - sliceFlattenView(qkv, k, 1, head_groupnum - 2, 1); - sliceFlattenView(qkv, v, 1, head_groupnum - 1, 1); + sliceFlattenView(qkv, &q, 1, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 1, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 1, head_groupnum - 1, 1); FlashAttnUnpaddedBaseKernel(ctx, q, k, @@ -471,9 +471,9 @@ void FlashAttnQKVPackedKernel( #ifdef PADDLE_WITH_FLASHATTN const auto head_groupnum = qkv.dims()[2]; // nheads/nheads_k + 1 + 1 DenseTensor q, k, v; - sliceFlattenView(qkv, q, 2, 0, head_groupnum - 2); - sliceFlattenView(qkv, k, 2, head_groupnum - 2, 1); - sliceFlattenView(qkv, v, 2, head_groupnum - 1, 1); + sliceFlattenView(qkv, &q, 2, 0, head_groupnum - 2); + sliceFlattenView(qkv, &k, 2, head_groupnum - 2, 1); + sliceFlattenView(qkv, &v, 2, head_groupnum - 1, 1); FlashAttnBaseKernel(ctx, q, k, @@ -583,4 +583,4 @@ PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, phi::dtype::bfloat16) { kernel->InputAt(4).SetBackend( phi::Backend::ALL_BACKEND); // fixed_seed_offset -} \ No newline at end of file +} diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py 
index 0cce07d837daf..c61636d98f081 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -72,7 +72,9 @@ def _math_attention( query = paddle.transpose(query, [0, 2, 1, 3]) key = paddle.transpose(key, [0, 2, 1, 3]) value = paddle.transpose(value, [0, 2, 1, 3]) - product = paddle.matmul(x=query * (head_dim**-0.5), y=key, transpose_y=True) + product = paddle.matmul( + x=query * (head_dim**-0.5), y=key, transpose_y=True + ) if not causal: weights = F.softmax(product) @@ -371,7 +373,12 @@ def flash_attn_qkvpacked( if sdp_func_name == "flash_attn": if in_dynamic_or_pir_mode(): - (result_attention, result_softmax, _, _) = _C_ops.flash_attn_qkvpacked( + ( + result_attention, + result_softmax, + _, + _, + ) = _C_ops.flash_attn_qkvpacked( qkv, fixed_seed_offset, None, diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index c8f9509bd75e8..eb074da1e8929 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -26,9 +26,9 @@ from paddle.nn.functional.flash_attention import ( flash_attention, flash_attention_with_sparse_mask, + flash_attn_qkvpacked, flash_attn_unpadded, flash_attn_varlen_qkvpacked, - flash_attn_qkvpacked, scaled_dot_product_attention, ) from paddle.pir_utils import test_with_pir_api From 3122a6fde8fd68dac993e4a956041b94d7814623 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Mon, 8 Apr 2024 10:51:23 +0800 Subject: [PATCH 03/16] fix codestyle --- paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 2c2df92678ca5..4bec6b1825374 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -372,7 +372,7 @@ static void sliceFlattenView(const DenseTensor& in, DDim{dimArr.data(), in.dims().size() - 1}, DDim(strideArr.data(), in.dims().size() - 1)}}; out->set_offset(in.offset() + - offset * in.strides()[axis] * SizeOf(out.dtype())); + offset * in.strides()[axis] * SizeOf(out->dtype())); } template struct ZeroFunctor { diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index a7f7693e351cc..64eb8450bcac6 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -232,7 +232,7 @@ static void sliceFlattenView(const DenseTensor& in, DDim{dimArr.data(), in.dims().size() - 1}, DDim(strideArr.data(), in.dims().size() - 1)}}; out->set_offset(in.offset() + - offset * in.strides()[axis] * SizeOf(out.dtype())); + offset * in.strides()[axis] * SizeOf(out->dtype())); } template void FlashAttnVarlenQKVPackedKernel( From 1c9670592bb766dfd75bcd0228dbb0ec351c5bc3 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Mon, 8 Apr 2024 19:05:57 +0800 Subject: [PATCH 04/16] FlashAttention kvReduceGQA Performance Optimization --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 175 +++++++++++++----- 1 file changed, 133 insertions(+), 42 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4bec6b1825374..23f45564327d8 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -15,6 +15,7 @@ #include 
"paddle/phi/kernels/flash_attn_grad_kernel.h" #include #include "glog/logging.h" // For VLOG() +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/bfloat16.h" @@ -33,7 +34,7 @@ int get_num_split() { // 0 for an internal heuristic, which is optimal return FLAGS_cudnn_deterministic ? 1 : 0; } -bool isContiguous(const DenseTensor& t) { +static bool isContiguous(const DenseTensor& t) { auto rank = t.dims().size(); auto s = t.strides()[rank - 1]; if (s != 1) return false; @@ -46,41 +47,114 @@ bool isContiguous(const DenseTensor& t) { } return true; } -template -__global__ void SumStridedKV(const T* src, - T* dst, - size_t sRowDim1, - size_t sRowDim2, - size_t sRowDim3, - size_t sColDim, - size_t sRowStride1, - size_t sRowStride2, - size_t sRowStride3, - size_t sColStride, - size_t dRowStride1, - size_t dRowStride2, - size_t dRowStride3) { - for (size_t row1 = blockIdx.x; row1 < sRowDim1; row1 += gridDim.x) - for (size_t row2 = 0; row2 < sRowDim2; row2++) - for (size_t row3 = threadIdx.x; row3 < sRowDim3; row3 += blockDim.x) { - T v{0}; - for (size_t col = 0; col < sColDim; col++) { - v += src[row1 * sRowStride1 + row2 * sRowStride2 + - row3 * sRowStride3 + col * sColStride]; +template +static __global__ void SumStridedKV(const T* src, + T* dst, + const uint64_t sRowDim1, + const uint64_t sRowDim2, + const uint64_t sRowDim3, + const uint64_t sColDim, + const uint64_t sRowStride1, + const uint64_t sRowStride2, + const uint64_t sColStride, + const uint64_t dRowStride1, + const uint64_t dRowStride2) { + // SrcShape [seqlen, num_heads_k, num_heads/num_heads_k, headdim] + // AxisName [row1 , row2 , col , row3 ] + // LoopMap [blockx, thready , serialreduce , threadx] + // Ensure blockDim.x == 32 && blockDim.z == 1 + // Ensure sRowStride3 == dRowStride3 == 1 (headdim dim is contiguous) + using IndexType = uint64_t; + constexpr IndexType BlockDimX = 32; + const IndexType SRow1Begin = blockIdx.x * sRowStride1; + const IndexType SRow1End = sRowDim1 * sRowStride1; + const IndexType SRow1Stride = gridDim.x * sRowStride1; + + const IndexType SRow2Begin = threadIdx.y * sRowStride2; + const IndexType SRow2End = sRowDim2 * sRowStride2; + const IndexType SRow2Stride = blockDim.y * sRowStride2; + + // const IndexType SRow3Begin = threadIdx.x * sRowStride3; + // const IndexType SRow3End = sRowDim3 * sRowStride3; + // const IndexType SRow3Stride = BlockDimX * sRowStride3; + + constexpr IndexType SColBegin = 0; + const IndexType SColEnd = sColDim * sColStride; + const IndexType SColStride = sColStride; + + const IndexType DRow1Begin = blockIdx.x * dRowStride1; + const IndexType DRow1Stride = gridDim.x * dRowStride1; + + const IndexType DRow2Begin = threadIdx.y * dRowStride2; + const IndexType DRow2Stride = dRowStride2; + + // const IndexType DRow3Begin = threadIdx.x * dRowStride3; + // const IndexType DRow3Stride = blockDim.x * dRowStride3; + + for (auto row1 = SRow1Begin, drow1 = DRow1Begin; row1 < SRow1End; + row1 += SRow1Stride, drow1 += DRow1Stride) { + for (auto row2 = SRow2Begin, drow2 = DRow2Begin; row2 < SRow2End; + row2 += SRow2Stride, drow2 += DRow2Stride) { + const auto i1 = row1 + row2 + threadIdx.x; + const auto di1 = drow1 + drow2 + threadIdx.x; + T v[HeaddimDiv32]; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + v[i] = T{0}; + } + for (auto col = SColBegin; col < SColEnd; col += SColStride) { + const auto i2 = i1 + col; +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; 
i++) { + v[i] += src[i2 + i * BlockDimX]; } - dst[row1 * dRowStride1 + row2 * dRowStride2 + row3 * dRowStride3] = v; } +#pragma unroll + for (auto i = IndexType(0); i < HeaddimDiv32; i++) { + dst[di1 + i * BlockDimX] = v[i]; + } + } + } +} + +template +static auto selectSumkernel(int64_t headdim) { + PADDLE_ENFORCE_LE(headdim, 256, "FlashAttention only support headdim <= 256"); + PADDLE_ENFORCE_EQ( + headdim % 32, 0, "FlashAttention only support headdim %% 32 == 0"); + PADDLE_ENFORCE_NE(headdim, 0, "Headdim can't be zero"); +#define CASEN(n) \ + case n: \ + return SumStridedKV; + switch (headdim / 32) { + CASEN(1); + CASEN(2); + CASEN(3); + CASEN(4); + CASEN(5); + CASEN(6); + CASEN(7); + CASEN(8); + } + PADDLE_FATAL("Unreachable in selectSumKernel"); +#undef CASEN } template -void kvReduceForGQA(const Context& ctx, - const DenseTensor& dk_tmp, - DenseTensor* dk) { - const size_t reduceDimSize = dk_tmp.dims()[2]; +static void kvReduceForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[2], 1, "headdim dimention must be contiguous"); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[3], 1, "headdim dimention must be contiguous"); + const int64_t reduceDimSize = dk_tmp.dims()[2]; const size_t blockNum = - std::min((static_cast(dk_tmp.dims()[0] + 127) / 128), - static_cast(1024l)); - SumStridedKV<<>>( + std::min((static_cast(dk_tmp.dims()[0] + 31) / 32), + static_cast(1024l)); + constexpr dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[3]); + sumkernel<<>>( reinterpret_cast(dk_tmp.data()), reinterpret_cast(dk->data()), dk_tmp.dims()[0], @@ -89,20 +163,35 @@ void kvReduceForGQA(const Context& ctx, dk_tmp.dims()[2], dk_tmp.strides()[0], dk_tmp.strides()[1], - dk_tmp.strides()[3], + // dk_tmp.strides()[3], dk_tmp.strides()[2], dk->strides()[0], - dk->strides()[1], - dk->strides()[2]); + dk->strides()[1] + // dk->strides()[2] + ); } template -void kvReduceBatchedForGQA(const Context& ctx, - const DenseTensor& dk_tmp, - DenseTensor* dk) { - const size_t reduceDimSize = dk_tmp.dims()[3]; - const size_t blockNum = std::min((dk_tmp.dims()[0] + 127) / 128, 1024l); +static void kvReduceBatchedForGQA(const Context& ctx, + const DenseTensor& dk_tmp, + DenseTensor* dk) { + PADDLE_ENFORCE_EQ( + dk->strides()[3], 1, "headdim dimention must be contiguous"); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[4], 1, "headdim dimention must be contiguous"); + PADDLE_ENFORCE_EQ(dk->strides()[0], + dk->strides()[1] * dk->dims()[1], + "batchsize dimention must be contiguous"); + PADDLE_ENFORCE_EQ(dk_tmp.strides()[0], + dk_tmp.strides()[1] * dk_tmp.dims()[1], + "batchsize dimention must be contiguous"); + const int64_t reduceDimSize = dk_tmp.dims()[3]; + const size_t blockNum = std::min( + (static_cast(dk_tmp.dims()[0] * dk_tmp.dims()[1] + 31) / 32), + static_cast(1024l)); + constexpr dim3 threadNum{32, 4, 1}; + auto sumkernel = selectSumkernel(dk_tmp.dims()[4]); // here implicitly flat [batch,seqlen], and require batch dim to be contiguous - SumStridedKV<<>>( + sumkernel<<>>( reinterpret_cast(dk_tmp.data()), reinterpret_cast(dk->data()), dk_tmp.dims()[0] * dk_tmp.dims()[1], @@ -111,12 +200,14 @@ void kvReduceBatchedForGQA(const Context& ctx, dk_tmp.dims()[3], dk_tmp.strides()[1], dk_tmp.strides()[2], - dk_tmp.strides()[4], + // dk_tmp.strides()[4], dk_tmp.strides()[3], dk->strides()[1], - dk->strides()[2], - dk->strides()[3]); + dk->strides()[2] + // dk->strides()[3] + ); } + template void FlashAttnUnpaddedGradBaseKernel( const Context& ctx, From 
d7b113c28dd0441e94155336245c6573dcd56163 Mon Sep 17 00:00:00 2001 From: Xiao Xiyuan <945428667@qq.com> Date: Mon, 8 Apr 2024 20:39:28 +0800 Subject: [PATCH 05/16] Fix problem with windows --- paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 23f45564327d8..6583012d60c47 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -152,7 +152,7 @@ static void kvReduceForGQA(const Context& ctx, const size_t blockNum = std::min((static_cast(dk_tmp.dims()[0] + 31) / 32), static_cast(1024l)); - constexpr dim3 threadNum{32, 4, 1}; + const dim3 threadNum{32, 4, 1}; auto sumkernel = selectSumkernel(dk_tmp.dims()[3]); sumkernel<<>>( reinterpret_cast(dk_tmp.data()), @@ -188,7 +188,7 @@ static void kvReduceBatchedForGQA(const Context& ctx, const size_t blockNum = std::min( (static_cast(dk_tmp.dims()[0] * dk_tmp.dims()[1] + 31) / 32), static_cast(1024l)); - constexpr dim3 threadNum{32, 4, 1}; + const dim3 threadNum{32, 4, 1}; auto sumkernel = selectSumkernel(dk_tmp.dims()[4]); // here implicitly flat [batch,seqlen], and require batch dim to be contiguous sumkernel<<>>( From ee084b21787176f311c8a5b7da9c3e0e9ce958fc Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Tue, 9 Apr 2024 14:54:11 +0800 Subject: [PATCH 06/16] code clean --- paddle/phi/core/kernel_factory.cc | 7 +++--- paddle/phi/core/kernel_factory.h | 3 --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 22 +++++-------------- .../paddle/nn/functional/flash_attention.py | 2 +- 4 files changed, 9 insertions(+), 25 deletions(-) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 0bf8403f15016..32644cfe8bf63 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -270,8 +270,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( auto kernel_iter = iter->second.find( {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()}); if (kernel_iter != iter->second.end()) { - return { - kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; + return {kernel_iter->second, false, false}; } kernel_key = KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype()); @@ -352,7 +351,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( << ", expected_kernel_key:" << kernel_key << ", fallbacking to CPU one!"; - return {kernel_iter->second, true, kernel_iter->second.IsSupportStride()}; + return {kernel_iter->second, true, false}; } PADDLE_ENFORCE_NE( @@ -367,7 +366,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_name, KernelSelectionErrorMessage(kernel_name, kernel_key))); - return {kernel_iter->second, false, kernel_iter->second.IsSupportStride()}; + return {kernel_iter->second, false, false}; } const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 5960a4dc3bc59..4bdb8482c43b7 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -281,8 +281,6 @@ class Kernel { return kernel_registered_type_; } - bool IsSupportStride() const { return support_stride_; } - void SetSupportStride(bool support) { support_stride_ = support; } GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr}; std::function check_if_onednn_kernel_support_{ nullptr}; @@ -292,7 +290,6 @@ class Kernel { void* 
variadic_fn_ = nullptr; KernelArgsDef args_def_; KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION; - bool support_stride_ = false; }; using KernelKeyMap = paddle::flat_hash_map; diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 23f45564327d8..7d15483824eb5 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -34,19 +34,7 @@ int get_num_split() { // 0 for an internal heuristic, which is optimal return FLAGS_cudnn_deterministic ? 1 : 0; } -static bool isContiguous(const DenseTensor& t) { - auto rank = t.dims().size(); - auto s = t.strides()[rank - 1]; - if (s != 1) return false; - for (auto i = rank - 1; i > 0;) { - s *= t.dims()[i]; - i--; - if (t.strides()[i] != s) { - return false; - } - } - return true; -} + template static __global__ void SumStridedKV(const T* src, T* dst, @@ -363,13 +351,13 @@ void FlashAttnUnpaddedGradBaseKernel( CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - if (isContiguous(*dk)) + if (dk->meta().is_contiguous()) phi::SumKernel(ctx, dk_tmp, {2}, dk->type(), false, dk); else kvReduceForGQA(ctx, dk_tmp, dk); } if (dv) { - if (isContiguous(*dv)) + if (dv->meta().is_contiguous()) phi::SumKernel(ctx, dv_tmp, {2}, dv->type(), false, dv); else kvReduceForGQA(ctx, dv_tmp, dv); @@ -703,14 +691,14 @@ void FlashAttnGradBaseKernel( CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { - if (isContiguous(*dk)) + if (dk->meta().is_contiguous()) phi::SumKernel(ctx, dk_tmp, {3}, dk->type(), false, dk); else kvReduceBatchedForGQA(ctx, dk_tmp, dk); } if (dv) { - if (isContiguous(*dv)) + if (dv->meta().is_contiguous()) phi::SumKernel(ctx, dv_tmp, {3}, dv->type(), false, dv); else kvReduceBatchedForGQA(ctx, dv_tmp, dv); diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index c61636d98f081..f2045c71b7b74 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -353,7 +353,7 @@ def flash_attn_qkvpacked( >>> paddle.seed(2023) >>> q = paddle.rand((1, 128, 2, 16)) - >>> qkv = paddle.stack([q,q,q], axis=3) + >>> qkv = paddle.stack([q,q,q], axis=2) >>> output = paddle.nn.functional.flash_attention.flash_attn_qkvpacked(qkv, 0.9, False, False) >>> print(output) (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, From 78e0a12c5504260160f8e3c65c590a747649c21b Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Tue, 23 Apr 2024 17:48:33 +0800 Subject: [PATCH 07/16] update third_party/flashattn --- third_party/flashattn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/flashattn b/third_party/flashattn index 2521ed0f49750..22b604199d911 160000 --- a/third_party/flashattn +++ b/third_party/flashattn @@ -1 +1 @@ -Subproject commit 2521ed0f49750e45cb0a9769aaf445e28d776809 +Subproject commit 22b604199d911d4e155fe9e54124148c7a290263 From be357d4e1fae6a0d710e700b9aa49fa358a4e8eb Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Thu, 25 Apr 2024 16:36:07 +0800 Subject: [PATCH 08/16] update errormsg and docs --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 44 +++++++++++++------ .../paddle/nn/functional/flash_attention.py | 1 + 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4642bb4219eef..1e919c122bf03 100644 --- 
a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -107,10 +107,16 @@ static __global__ void SumStridedKV(const T* src, template static auto selectSumkernel(int64_t headdim) { - PADDLE_ENFORCE_LE(headdim, 256, "FlashAttention only support headdim <= 256"); - PADDLE_ENFORCE_EQ( - headdim % 32, 0, "FlashAttention only support headdim %% 32 == 0"); - PADDLE_ENFORCE_NE(headdim, 0, "Headdim can't be zero"); + PADDLE_ENFORCE_LE(headdim, + 256, + phi::errors::InvalidArgument( + "FlashAttention only support headdim <= 256")); + PADDLE_ENFORCE_EQ(headdim % 32, + 0, + phi::errors::InvalidArgument( + "FlashAttention only support headdim %% 32 == 0")); + PADDLE_ENFORCE_NE( + headdim, 0, phi::errors::InvalidArgument("Headdim can't be zero")); #define CASEN(n) \ case n: \ return SumStridedKV; @@ -133,9 +139,13 @@ static void kvReduceForGQA(const Context& ctx, const DenseTensor& dk_tmp, DenseTensor* dk) { PADDLE_ENFORCE_EQ( - dk->strides()[2], 1, "headdim dimention must be contiguous"); + dk->strides()[2], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); PADDLE_ENFORCE_EQ( - dk_tmp.strides()[3], 1, "headdim dimention must be contiguous"); + dk_tmp.strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); const int64_t reduceDimSize = dk_tmp.dims()[2]; const size_t blockNum = std::min((static_cast(dk_tmp.dims()[0] + 31) / 32), @@ -163,15 +173,21 @@ static void kvReduceBatchedForGQA(const Context& ctx, const DenseTensor& dk_tmp, DenseTensor* dk) { PADDLE_ENFORCE_EQ( - dk->strides()[3], 1, "headdim dimention must be contiguous"); + dk->strides()[3], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[4], + 1, + phi::errors::InvalidArgument("headdim dimention must be contiguous")); PADDLE_ENFORCE_EQ( - dk_tmp.strides()[4], 1, "headdim dimention must be contiguous"); - PADDLE_ENFORCE_EQ(dk->strides()[0], - dk->strides()[1] * dk->dims()[1], - "batchsize dimention must be contiguous"); - PADDLE_ENFORCE_EQ(dk_tmp.strides()[0], - dk_tmp.strides()[1] * dk_tmp.dims()[1], - "batchsize dimention must be contiguous"); + dk->strides()[0], + dk->strides()[1] * dk->dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); + PADDLE_ENFORCE_EQ( + dk_tmp.strides()[0], + dk_tmp.strides()[1] * dk_tmp.dims()[1], + phi::errors::InvalidArgument("batchsize dimention must be contiguous")); const int64_t reduceDimSize = dk_tmp.dims()[3]; const size_t blockNum = std::min( (static_cast(dk_tmp.dims()[0] * dk_tmp.dims()[1] + 31) / 32), diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index f2045c71b7b74..076e0f844d6c1 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -324,6 +324,7 @@ def flash_attn_qkvpacked( Warning: This API is only support inputs with dtype float16 and bfloat16. + Don't call this API if flash_attn is not supported. Args: qkv(Tensor): The query/key/value packed tensor in the Attention module. 
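The contiguity checks tightened above guard the strided dk/dv reduction added for GQA. Numerically that reduction is just a sum over the query-group axis; the custom SumStridedKV path only exists so the result can be written straight into a non-contiguous slice of dqkv (the contiguous case already falls back to phi::SumKernel). A minimal NumPy sketch of the intended result, with illustrative sizes only:

    import numpy as np

    total_seqlen, num_heads_k, num_group, head_dim = 6, 2, 4, 32  # illustrative sizes
    # Varlen layout: dk_tmp is [total_seqlen, num_heads_k, num_heads/num_heads_k, head_dim]
    dk_tmp = np.random.rand(total_seqlen, num_heads_k, num_group, head_dim).astype(np.float32)
    dk_ref = dk_tmp.sum(axis=2)  # -> [total_seqlen, num_heads_k, head_dim]
    # The batched (dense) variant reduces axis 3 of
    # [batch, seqlen, num_heads_k, num_heads/num_heads_k, head_dim] in the same way.
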
From 4fc88bebe3eac52b6a7581f3a4267524ebe2600c Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Thu, 25 Apr 2024 16:45:33 +0800 Subject: [PATCH 09/16] update api --- python/paddle/nn/functional/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a929088753376..7722ffb437389 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -87,6 +87,8 @@ ) from .flash_attention import ( flash_attention_with_sparse_mask, + flash_attn_qkvpacked, + flash_attn_varlen_qkvpacked, scaled_dot_product_attention, sdp_kernel, # noqa: F401 ) @@ -279,5 +281,7 @@ 'gaussian_nll_loss', 'scaled_dot_product_attention', 'flash_attention_with_sparse_mask', + 'flash_attn_qkvpacked', + 'flash_attn_varlen_qkvpacked', 'group_norm', ] From b0d9af3e31ac7f865d25502feba027c2ecdb6121 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Thu, 25 Apr 2024 16:58:10 +0800 Subject: [PATCH 10/16] update doc --- python/paddle/nn/functional/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 076e0f844d6c1..8c2f05062e933 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -355,7 +355,7 @@ def flash_attn_qkvpacked( >>> paddle.seed(2023) >>> q = paddle.rand((1, 128, 2, 16)) >>> qkv = paddle.stack([q,q,q], axis=2) - >>> output = paddle.nn.functional.flash_attention.flash_attn_qkvpacked(qkv, 0.9, False, False) + >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False) >>> print(output) (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, [[[[0.34992966, 0.34456208, 0.45826620, ..., 0.39883569, @@ -658,7 +658,7 @@ def flash_attn_varlen_qkvpacked( >>> cu = paddle.arange(0, 384, 128, dtype='int32') >>> qq = paddle.reshape(q, [256, 8, 16]) >>> qkv = paddle.stack([qq,qq,qq], axis=2) - >>> output = paddle.nn.functional.flash_attention.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) + >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) """ if in_dynamic_mode(): From b708c0b7038822b22c814cbd583bafc6434571e8 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Thu, 25 Apr 2024 20:53:21 +0800 Subject: [PATCH 11/16] update doctest --- python/paddle/nn/functional/flash_attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 8c2f05062e933..e229540f01e6b 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -350,6 +350,7 @@ def flash_attn_qkvpacked( Examples: .. code-block:: python + >>> # doctest: +SKIP('flash_attn need A100 compile') >>> import paddle >>> paddle.seed(2023) @@ -367,7 +368,7 @@ def flash_attn_qkvpacked( 0.53336465, 0.54540104], [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, 0.40526697, 0.60541755]]]]), None) - + >>> # doctest: -SKIP """ head_dim = qkv.shape[-1] sdp_func_name = _select_sdp(head_dim) @@ -652,6 +653,7 @@ def flash_attn_varlen_qkvpacked( Examples: .. 
code-block:: python + >>> # doctest: +SKIP('flash_attn need A100 compile') >>> import paddle >>> paddle.seed(2023) >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') @@ -659,6 +661,7 @@ def flash_attn_varlen_qkvpacked( >>> qq = paddle.reshape(q, [256, 8, 16]) >>> qkv = paddle.stack([qq,qq,qq], axis=2) >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) + >>> # doctest: -SKIP """ if in_dynamic_mode(): From 2dfbba856b3959e94edfe84f45a4ebbc13c2bf37 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Fri, 26 Apr 2024 14:37:23 +0800 Subject: [PATCH 12/16] update doc, test=document_fix --- python/paddle/nn/functional/flash_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index e229540f01e6b..6922e8e95c4cf 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -369,6 +369,7 @@ def flash_attn_qkvpacked( [0.59137970, 0.51350880, 0.50449550, ..., 0.38860250, 0.40526697, 0.60541755]]]]), None) >>> # doctest: -SKIP + """ head_dim = qkv.shape[-1] sdp_func_name = _select_sdp(head_dim) From 4b1c5678576cbe2e10140e7010050525594995a4 Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Sun, 28 Apr 2024 19:18:56 +0800 Subject: [PATCH 13/16] update doc, test=document_fix --- .../paddle/nn/functional/flash_attention.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 6922e8e95c4cf..197a17c3e5e97 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -323,7 +323,7 @@ def flash_attn_qkvpacked( ``d`` represents the size of the last dimension of the three parameters. Warning: - This API is only support inputs with dtype float16 and bfloat16. + This API only supports inputs with dtype float16 and bfloat16. Don't call this API if flash_attn is not supported. Args: @@ -342,9 +342,7 @@ def flash_attn_qkvpacked( :ref:`api_guide_Name`. Returns: - out(Tensor): The attention tensor. - 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. - The dtype can be float16 or bfloat16. + out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. softmax(Tensor): The softmax tensor. None if return_softmax is False. Examples: @@ -355,7 +353,7 @@ def flash_attn_qkvpacked( >>> paddle.seed(2023) >>> q = paddle.rand((1, 128, 2, 16)) - >>> qkv = paddle.stack([q,q,q], axis=2) + >>> qkv = paddle.stack([q, q, q], axis=2) >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False) >>> print(output) (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True, @@ -516,6 +514,9 @@ def flash_attn_unpadded( :ref:`api_guide_Name`. Returns: + out(Tensor): The attention tensor. + 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. + The dtype can be float16 or bfloat16. out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. @@ -621,7 +622,7 @@ def flash_attn_varlen_qkvpacked( ``d`` represents the size of the last dimension of the three parameters. Warning: - This API is only support inputs with dtype float16 and bfloat16. 
+ This API only supports inputs with dtype float16 and bfloat16. Args: qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed @@ -646,9 +647,7 @@ def flash_attn_varlen_qkvpacked( :ref:`api_guide_Name`. Returns: - out(Tensor): The attention tensor. The tensor is padded by zeros. - 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. - The dtype can be float16 or bfloat16. + out(Tensor): The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. softmax(Tensor): The softmax tensor. None if return_softmax is False. Examples: @@ -660,7 +659,7 @@ def flash_attn_varlen_qkvpacked( >>> q = paddle.rand((2, 128, 8, 16), dtype='float16') >>> cu = paddle.arange(0, 384, 128, dtype='int32') >>> qq = paddle.reshape(q, [256, 8, 16]) - >>> qkv = paddle.stack([qq,qq,qq], axis=2) + >>> qkv = paddle.stack([qq, qq, qq], axis=2) >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False) >>> # doctest: -SKIP From 1ddd78d992d13368043c31c11a244492591ff887 Mon Sep 17 00:00:00 2001 From: Xiao Xiyuan <945428667@qq.com> Date: Mon, 29 Apr 2024 11:08:59 +0800 Subject: [PATCH 14/16] Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> --- python/paddle/nn/functional/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 197a17c3e5e97..f9ac7743cb89f 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -647,8 +647,8 @@ def flash_attn_varlen_qkvpacked( :ref:`api_guide_Name`. Returns: - out(Tensor): The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. - softmax(Tensor): The softmax tensor. None if return_softmax is False. + - out(Tensor). The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. Examples: .. code-block:: python From 7091c846f883e028be7119fa87b38a14b21e063e Mon Sep 17 00:00:00 2001 From: Xiao Xiyuan <945428667@qq.com> Date: Mon, 29 Apr 2024 11:09:09 +0800 Subject: [PATCH 15/16] Update python/paddle/nn/functional/flash_attention.py Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> --- python/paddle/nn/functional/flash_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index f9ac7743cb89f..5bb0cc8f143a5 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -342,8 +342,8 @@ def flash_attn_qkvpacked( :ref:`api_guide_Name`. Returns: - out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. - softmax(Tensor): The softmax tensor. None if return_softmax is False. + - out(Tensor). The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16. + - softmax(Tensor). The softmax tensor. None if return_softmax is False. 
Examples: .. code-block:: python From 0376214b3570149ef4f3dbc668487f4ba3ef7bff Mon Sep 17 00:00:00 2001 From: kircle <945428667@qq.com> Date: Mon, 29 Apr 2024 11:10:51 +0800 Subject: [PATCH 16/16] update doc --- python/paddle/nn/functional/flash_attention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 5bb0cc8f143a5..84c7882a7151d 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -514,9 +514,6 @@ def flash_attn_unpadded( :ref:`api_guide_Name`. Returns: - out(Tensor): The attention tensor. - 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. - The dtype can be float16 or bfloat16. out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
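
Taken together, the patches above define the packed layout consumed by flash_attn_qkvpacked: along the head-group axis the first num_heads/num_heads_k slices hold the query heads and the last two hold K and V. A rough Python equivalent of the kernel-side sliceFlattenView split, with illustrative shapes (a sketch of the layout, not part of the public API):

    import paddle

    batch, seqlen, num_group, num_heads_k, head_dim = 1, 128, 4, 2, 16  # illustrative sizes
    qkv = paddle.rand((batch, seqlen, num_group + 2, num_heads_k, head_dim), dtype='float16')

    # Dense layout: slice along axis 2, flattening the query group into the head axis.
    q = qkv[:, :, :num_group].reshape((batch, seqlen, num_group * num_heads_k, head_dim))
    k = qkv[:, :, num_group]      # [batch, seqlen, num_heads_k, head_dim]
    v = qkv[:, :, num_group + 1]  # [batch, seqlen, num_heads_k, head_dim]

    # The varlen variant applies the same split along axis 1 of
    # [total_seqlen, num_group + 2, num_heads_k, head_dim].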