Support FlashAttentionWithSparseMask (PaddlePaddle#35)
* support sparse attention mask

* fix attention mask index overflow beyond uint32

* use shared memory, support shared heads, skip blocks with no mask

* fix forward sparse mask row_offset_sparse_mask and optimize apply_sparse_mask_causal
GuoxiaWang authored Feb 21, 2024
1 parent 5fc132a commit d98d8a3
Showing 9 changed files with 149 additions and 18 deletions.
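This commit threads a new sparse mask representation, attn_mask_start_row_indices, through the C API (flash_attn.cu / flash_attn.h), the kernel parameter struct (flash.h), and both the forward and backward kernels. The diff does not spell out the mask semantics, so the sketch below is only an assumed illustration, inferred from the apply_sparse_mask_causal call sites further down: for each key column, the int32 entry names the first query row at which attention to that column is masked out, on top of the usual causal rule. The helper name expand_sparse_mask and the [batch, heads, seqlen_k] layout are hypothetical.

// Illustrative sketch only: expand the assumed sparse-mask encoding into a dense boolean mask.
#include <cstdint>
#include <vector>

std::vector<uint8_t> expand_sparse_mask(const std::vector<int32_t> &start_rows, // [batch * heads * seqlen_k]
                                        int batch, int heads, int seqlen_q, int seqlen_k) {
    std::vector<uint8_t> masked(size_t(batch) * heads * seqlen_q * seqlen_k, 0);
    for (int b = 0; b < batch; ++b)
        for (int h = 0; h < heads; ++h)
            for (int col = 0; col < seqlen_k; ++col) {
                const int start_row = start_rows[(size_t(b) * heads + h) * seqlen_k + col];
                for (int row = 0; row < seqlen_q; ++row) {
                    const bool causal = col > row;        // standard causal mask
                    const bool sparse = row >= start_row; // assumed sparse-mask rule
                    masked[((size_t(b) * heads + h) * seqlen_q + row) * seqlen_k + col] = causal || sparse;
                }
            }
    return masked;
}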
32 changes: 28 additions & 4 deletions csrc/capi/flash_attn.cu
@@ -114,6 +114,8 @@ void set_params_fprop(Flash_fwd_params &params,
bool is_causal,
bool is_bf16,
void * attn_mask = nullptr,
void * attn_mask_start_row_indices = nullptr,
const int attn_mask_start_row = 0,
int mask_head_mod_size = 0,
int mask_seq_q_mod_size = 0) {
// Reset the parameters
@@ -169,6 +171,10 @@ void set_params_fprop(Flash_fwd_params &params,
params.mask_head_mod_size = mask_head_mod_size;
params.mask_seq_q_mod_size = mask_seq_q_mod_size;

// sparse mask row index
params.attn_mask_start_row_indices_ptr = attn_mask_start_row_indices;
params.attn_mask_start_row = attn_mask_start_row;

// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
@@ -222,6 +228,8 @@ void set_params_dgrad(Flash_bwd_params &params,
bool is_bf16,
const int num_splits = 0,
void * attn_mask = nullptr,
void * attn_mask_start_row_indices = nullptr,
const int attn_mask_start_row = 0,
int mask_head_mod_size = 0,
int mask_seq_q_mod_size = 0) {

@@ -238,6 +246,8 @@ void set_params_dgrad(Flash_bwd_params &params,
is_causal,
is_bf16,
attn_mask,
attn_mask_start_row_indices,
attn_mask_start_row,
mask_head_mod_size,
mask_seq_q_mod_size);

@@ -309,10 +319,13 @@ bool flash_attn_fwd(const void * const q,
uint64_t seed,
uint64_t offset,
const void * const attn_mask,
const int64_t * const mask_dims) {
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const int attn_mask_start_row) {
FLASHATTNLIB_BEGIN_FUNC
const bool is_dropout = p_dropout > 0.0;
const int mask_head_mod_size = attn_mask ? mask_dims[1] : 0;
const int mask_head_mod_size = attn_mask ? mask_dims[1] : attn_mask_start_row_indices ? attn_mask_start_row_indices_dims[1] : 0;
const int mask_seq_q_mod_size = attn_mask ? mask_dims[2] : 0;

CHECK_FWD_EXECTUABLE(seqlen_q, seqlen_k)
@@ -338,6 +351,8 @@ bool flash_attn_fwd(const void * const q,
is_causal,
is_bf16,
const_cast<void *>(attn_mask),
const_cast<void *>(attn_mask_start_row_indices),
attn_mask_start_row,
mask_head_mod_size,
mask_seq_q_mod_size);

@@ -414,6 +429,8 @@ bool flash_attn_varlen_fwd(const void * const q,
is_causal,
is_bf16,
const_cast<void *>(attn_mask),
nullptr,
-1,
mask_head_mod_size,
mask_seq_q_mod_size);

@@ -483,10 +500,13 @@ bool flash_attn_bwd(const void * const dout,
uint64_t seed,
uint64_t offset,
const void * const attn_mask,
const int64_t * const mask_dims) {
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const int attn_mask_start_row) {
FLASHATTNLIB_BEGIN_FUNC
const bool is_dropout = p_dropout > 0.0;
const int mask_head_mod_size = attn_mask ? mask_dims[1] : 0;
const int mask_head_mod_size = attn_mask ? mask_dims[1] : attn_mask_start_row_indices ? attn_mask_start_row_indices_dims[1] : 0;
const int mask_seq_q_mod_size = attn_mask ? mask_dims[2] : 0;

CHECK_BWD_EXECTUABLE(seqlen_q, seqlen_k)
@@ -525,6 +545,8 @@ bool flash_attn_bwd(const void * const dout,
is_bf16,
num_splits,
const_cast<void *>(attn_mask),
const_cast<void *>(attn_mask_start_row_indices),
attn_mask_start_row,
mask_head_mod_size,
mask_seq_q_mod_size);

@@ -619,6 +641,8 @@ bool flash_attn_varlen_bwd(const void * const dout,
is_bf16,
num_splits,
const_cast<void *>(attn_mask),
nullptr,
-1,
mask_head_mod_size,
mask_seq_q_mod_size);

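In the C API above, mask_head_mod_size now falls back to the second dimension of the sparse index tensor when no dense mask is given, and the varlen entry points pass nullptr / -1, i.e. they do not take the sparse path. The nested ternary is easy to misread; the sketch below spells out the assumed fallback order with a hypothetical helper name and assumed dim layouts.

// Sketch only: equivalent, more explicit form of the mask_head_mod_size selection above.
#include <cstdint>

static inline int pick_mask_head_mod_size(const void *attn_mask,
                                          const int64_t *mask_dims,
                                          const void *attn_mask_start_row_indices,
                                          const int64_t *attn_mask_start_row_indices_dims) {
    if (attn_mask != nullptr) {
        return static_cast<int>(mask_dims[1]);                        // dense mask dims, assumed [batch, heads, seq_q, ...]
    }
    if (attn_mask_start_row_indices != nullptr) {
        return static_cast<int>(attn_mask_start_row_indices_dims[1]); // sparse indices, assumed [batch, heads, seq_k]
    }
    return 0;                                                         // no mask of either kind
}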
10 changes: 8 additions & 2 deletions csrc/capi/flash_attn.h
@@ -34,7 +34,10 @@ bool flash_attn_fwd(const void * const q, // batch_size x seqlen_q x num
uint64_t seed,
uint64_t offset,
const void * const attn_mask,
const int64_t * const mask_dims);
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const int attn_mask_start_row);

bool flash_attn_varlen_fwd(const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const void * const k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -97,7 +100,10 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
uint64_t seed,
uint64_t offset,
const void * const attn_mask,
const int64_t * const mask_dims);
const int64_t * const mask_dims,
const void * const attn_mask_start_row_indices,
const int64_t * const attn_mask_start_row_indices_dims,
const int attn_mask_start_row);

bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x head_size
const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
2 changes: 2 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -107,6 +107,8 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ attn_mask_ptr;
int mask_head_mod_size;
int mask_seq_q_mod_size;
void * __restrict__ attn_mask_start_row_indices_ptr;
int attn_mask_start_row;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
25 changes: 23 additions & 2 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -427,11 +427,15 @@ inline __device__ void convert_dKV(const Params &params) {
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Is_attn_mask, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

const bool Is_sparse_attn_mask = params.attn_mask_start_row_indices_ptr != nullptr;
const int attn_mask_start_row = params.attn_mask_start_row;

using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;

// Shared memory.
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -475,11 +479,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
+ (m_block_max - 1) * kBlockM;

const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
const uint64_t row_offset_mask = (uint64_t)((bidb * params.mask_head_mod_size
+ (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size
+ ((m_block_max - 1) * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
+ n_block * kBlockN;

const index_t row_offset_sparse_mask = (bidb * params.mask_head_mod_size + bidh % params.mask_head_mod_size) * params.seqlen_k + n_block * kBlockN;

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
@@ -508,6 +514,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k, _1{}));
Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{});
@@ -531,6 +539,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
// sP and sdQ share the same memory so be careful
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});

@@ -796,6 +805,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
}

if (Is_sparse_attn_mask) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
}
__syncthreads();
}

auto seed = params.rng_state[0];
auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;

@@ -849,7 +865,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements not beyond actual_seqlen_k.
if (m_block * kBlockM < (n_block + 1) * kBlockN

if (Is_sparse_attn_mask && m_block * kBlockM >= attn_mask_start_row) {
flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
AtomLayoutMS * 16, n_block * kBlockN);
} else if (m_block * kBlockM < (n_block + 1) * kBlockN
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
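One of the commit's bullet points is the uint32 overflow fix: row_offset_mask is now computed through a uint64_t cast in both the backward and forward kernels. A minimal host-side illustration (not the kernel code; the sizes are only an example) of why the wider type matters for long sequences:

// Sketch: for large batch * heads * seqlen_q * seqlen_k, the element offset into a
// dense attention mask no longer fits in 32 bits, so the product must be formed in
// 64-bit arithmetic, as the (uint64_t) cast in row_offset_mask now does.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t batch = 8, heads = 32, seqlen_q = 8192, seqlen_k = 8192;
    const int64_t elems = batch * heads * seqlen_q * seqlen_k;  // 2^34 elements in this example
    std::printf("mask elements: %lld, fits in uint32: %s\n",
                (long long)elems, elems <= (int64_t)UINT32_MAX ? "yes" : "no");
    return 0;
}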
1 change: 1 addition & 0 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_deterministic = params.num_splits == 1;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
params.attn_mask_start_row = (int)(params.attn_mask_start_row / Kernel_traits::kBlockM) * Kernel_traits::kBlockM;
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
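Both launch templates round attn_mask_start_row down to a multiple of the M-tile size before launching, so the kernels can compare m_block * kBlockM against it and take the sparse-mask branch for whole row blocks at a time. A small sketch of that alignment; the kBlockM value and helper name are illustrative, the real value comes from Kernel_traits.

// Sketch: round the first sparse-mask row down to an M-block boundary.
constexpr int kBlockM = 128;  // illustrative tile size only

inline int align_start_row_to_block(int attn_mask_start_row) {
    return (attn_mask_start_row / kBlockM) * kBlockM;  // e.g. 300 -> 256, 128 -> 128
}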
52 changes: 45 additions & 7 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -121,11 +121,15 @@ inline __device__ void write_softmax_to_gmem(
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, bool Is_attn_mask, bool Is_equal_seq_qk, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

const bool Is_sparse_attn_mask = params.attn_mask_start_row_indices_ptr != nullptr;
const int attn_mask_start_row = params.attn_mask_start_row;

using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;

// Shared memory.
__shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
extern __shared__ char smem_[];

// The thread index.
@@ -171,11 +175,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
const uint64_t row_offset_mask = (uint64_t)((bidb * params.mask_head_mod_size
+ (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size
+ (m_block * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
+ (n_block_max - 1) * kBlockN;

const index_t row_offset_sparse_mask = (bidb * params.mask_head_mod_size + bidh % params.mask_head_mod_size) * params.seqlen_k + (n_block_max - 1) * kBlockN;

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
@@ -193,6 +199,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k, _1{}));

Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
Shape<Int<kBlockN>>{});

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
@@ -201,6 +210,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});

typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -406,12 +416,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
if (Is_sparse_attn_mask && m_block * kBlockM >= attn_mask_start_row) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
}
__syncthreads();
flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16, n_block * kBlockN);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
} else {
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
}

flash::cp_async_wait<0>();
@@ -500,6 +524,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
params.unscale_softmax);
tPgMask.data() = tPgMask.data() + (-kBlockN);
}
if (Is_causal && Is_sparse_attn_mask && m_block * kBlockM >= params.attn_mask_start_row) {
if (tidx < kBlockN) {
sSparseMask(tidx) = gSparseMask(tidx);
}
__syncthreads();
flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16, n_block * kBlockN);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
gSparseMask.data() = gSparseMask.data() + (-kBlockN);
}

if (Is_equal_seq_qk) {
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
} else {
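Both kernels stage the current key block's start rows through a small statically allocated shared-memory buffer (sparse_mask_smem_): the first kBlockN threads copy from global memory, a barrier makes the values visible, and apply_sparse_mask_causal then reads them per column. A stripped-down CUDA sketch of that pattern follows; the kernel name, g_start_rows, s_start_rows, and out are illustrative stand-ins, not the kernel's actual CuTe tensors.

// Sketch of the shared-memory staging used for the per-column sparse-mask start rows.
#include <cstdint>

template <int kBlockN>
__global__ void stage_sparse_mask(const int32_t *__restrict__ g_start_rows,  // start rows for this key block
                                  int32_t *__restrict__ out) {
    __shared__ int32_t s_start_rows[kBlockN];
    const int tidx = threadIdx.x;
    if (tidx < kBlockN) {
        s_start_rows[tidx] = g_start_rows[tidx];  // one load per key column
    }
    __syncthreads();  // make the block's start rows visible to every thread
    if (tidx < kBlockN) {
        // Stand-in for apply_sparse_mask_causal: each thread would compare its query
        // rows against s_start_rows[col] and set masked scores to -INFINITY.
        out[tidx] = s_start_rows[tidx];
    }
}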
1 change: 1 addition & 0 deletions csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool return_softmax = params.p_ptr != nullptr;
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_equal_qk = (params.cu_seqlens_q == nullptr) && (params.cu_seqlens_k == nullptr) && (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
params.attn_mask_start_row = (int)(params.attn_mask_start_row / Kernel_traits::kBlockM) * Kernel_traits::kBlockM;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {