diff --git a/csrc/capi/flash_attn.cu b/csrc/capi/flash_attn.cu
index f256d7ee1b1491..5ef982f723a9df 100644
--- a/csrc/capi/flash_attn.cu
+++ b/csrc/capi/flash_attn.cu
@@ -114,6 +114,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       bool is_causal,
                       bool is_bf16,
                       void * attn_mask = nullptr,
+                      void * attn_mask_start_row_indices = nullptr,
+                      const int attn_mask_start_row = 0,
                       int mask_head_mod_size = 0,
                       int mask_seq_q_mod_size = 0) {
     // Reset the parameters
@@ -169,6 +171,10 @@ void set_params_fprop(Flash_fwd_params &params,
     params.mask_head_mod_size = mask_head_mod_size;
     params.mask_seq_q_mod_size = mask_seq_q_mod_size;
 
+    // sparse mask row index
+    params.attn_mask_start_row_indices_ptr = attn_mask_start_row_indices;
+    params.attn_mask_start_row = attn_mask_start_row;
+
     // Set the different scale values.
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = softmax_scale * M_LOG2E;
@@ -222,6 +228,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       bool is_bf16,
                       const int num_splits = 0,
                       void * attn_mask = nullptr,
+                      void * attn_mask_start_row_indices = nullptr,
+                      const int attn_mask_start_row = 0,
                       int mask_head_mod_size = 0,
                       int mask_seq_q_mod_size = 0) {
 
@@ -238,6 +246,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      is_causal,
                      is_bf16,
                      attn_mask,
+                     attn_mask_start_row_indices,
+                     attn_mask_start_row,
                      mask_head_mod_size,
                      mask_seq_q_mod_size);
 
@@ -309,10 +319,13 @@ bool flash_attn_fwd(const void * const q,
                     uint64_t seed,
                     uint64_t offset,
                     const void * const attn_mask,
-                    const int64_t * const mask_dims) {
+                    const int64_t * const mask_dims,
+                    const void * const attn_mask_start_row_indices,
+                    const int64_t * const attn_mask_start_row_indices_dims,
+                    const int attn_mask_start_row) {
     FLASHATTNLIB_BEGIN_FUNC
     const bool is_dropout = p_dropout > 0.0;
-    const int mask_head_mod_size = attn_mask ? mask_dims[1] : 0;
+    const int mask_head_mod_size = attn_mask ? mask_dims[1] : attn_mask_start_row_indices ? attn_mask_start_row_indices_dims[1] : 0;
     const int mask_seq_q_mod_size = attn_mask ? mask_dims[2] : 0;
 
     CHECK_FWD_EXECTUABLE(seqlen_q, seqlen_k)
 
@@ -338,6 +351,8 @@ bool flash_attn_fwd(const void * const q,
                      is_causal,
                      is_bf16,
                      const_cast<void *>(attn_mask),
+                     const_cast<void *>(attn_mask_start_row_indices),
+                     attn_mask_start_row,
                      mask_head_mod_size,
                      mask_seq_q_mod_size);
 
@@ -414,6 +429,8 @@ bool flash_attn_varlen_fwd(const void * const q,
                      is_causal,
                      is_bf16,
                      const_cast<void *>(attn_mask),
+                     nullptr,
+                     -1,
                      mask_head_mod_size,
                      mask_seq_q_mod_size);
 
@@ -483,10 +500,13 @@ bool flash_attn_bwd(const void * const dout,
                     uint64_t seed,
                     uint64_t offset,
                     const void * const attn_mask,
-                    const int64_t * const mask_dims) {
+                    const int64_t * const mask_dims,
+                    const void * const attn_mask_start_row_indices,
+                    const int64_t * const attn_mask_start_row_indices_dims,
+                    const int attn_mask_start_row) {
     FLASHATTNLIB_BEGIN_FUNC
     const bool is_dropout = p_dropout > 0.0;
-    const int mask_head_mod_size = attn_mask ? mask_dims[1] : 0;
+    const int mask_head_mod_size = attn_mask ? mask_dims[1] : attn_mask_start_row_indices ? attn_mask_start_row_indices_dims[1] : 0;
     const int mask_seq_q_mod_size = attn_mask ? mask_dims[2] : 0;
 
     CHECK_BWD_EXECTUABLE(seqlen_q, seqlen_k)
 
@@ -525,6 +545,8 @@ bool flash_attn_bwd(const void * const dout,
                      is_bf16,
                      num_splits,
                      const_cast<void *>(attn_mask),
+                     const_cast<void *>(attn_mask_start_row_indices),
+                     attn_mask_start_row,
                      mask_head_mod_size,
                      mask_seq_q_mod_size);
 
@@ -619,6 +641,8 @@ bool flash_attn_varlen_bwd(const void * const dout,
                      is_bf16,
                      num_splits,
                      const_cast<void *>(attn_mask),
+                     nullptr,
+                     -1,
                      mask_head_mod_size,
                      mask_seq_q_mod_size);
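For orientation between the C-API hunks above: the sparse variant replaces the dense [batch, heads, seqlen_q, seqlen_k] mask with one int32 start row per key column. Below is a host-side sketch of building that buffer for a sliding-window pattern; the helper name and the [batch, mask_head_mod_size, seqlen_k] layout are my reading of the kernel indexing, not something this patch ships.

    #include <cstdint>
    #include <vector>

    // Hypothetical helper (not in this patch): build start-row indices where
    // key column c stops being attended once the query row reaches c + window.
    // Layout assumed from row_offset_sparse_mask in the kernels:
    // [batch, mask_head_mod_size, seqlen_k], one int32 per key column.
    std::vector<int32_t> make_sliding_window_indices(int batch, int heads,
                                                     int seqlen_k, int window) {
        std::vector<int32_t> idx(static_cast<size_t>(batch) * heads * seqlen_k);
        for (int b = 0; b < batch; ++b)
            for (int h = 0; h < heads; ++h)
                for (int c = 0; c < seqlen_k; ++c)
                    idx[(static_cast<size_t>(b) * heads + h) * seqlen_k + c] = c + window;
        return idx;
    }

`attn_mask_start_row` would then be the smallest value in the buffer (here `window`); the launch templates floor it to tile granularity before the kernels compare block offsets against it. Note that the varlen entry points pass `nullptr, -1`, i.e. the sparse path is disabled there.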
diff --git a/csrc/capi/flash_attn.h b/csrc/capi/flash_attn.h
index 1cdabe48b68c10..77df7a200a7347 100644
--- a/csrc/capi/flash_attn.h
+++ b/csrc/capi/flash_attn.h
@@ -34,7 +34,10 @@ bool flash_attn_fwd(const void * const q, // batch_size x seqlen_q x num
                     uint64_t seed,
                     uint64_t offset,
                     const void * const attn_mask,
-                    const int64_t * const mask_dims);
+                    const int64_t * const mask_dims,
+                    const void * const attn_mask_start_row_indices,
+                    const int64_t * const attn_mask_start_row_indices_dims,
+                    const int attn_mask_start_row);
 
 bool flash_attn_varlen_fwd(const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                            const void * const k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -97,7 +100,10 @@ bool flash_attn_bwd(const void * const dout, // batch_size x seqlen_q x num_hea
                     uint64_t seed,
                     uint64_t offset,
                     const void * const attn_mask,
-                    const int64_t * const mask_dims);
+                    const int64_t * const mask_dims,
+                    const void * const attn_mask_start_row_indices,
+                    const int64_t * const attn_mask_start_row_indices_dims,
+                    const int attn_mask_start_row);
 
 bool flash_attn_varlen_bwd(const void * const dout, // total_q x num_heads, x head_size
                            const void * const q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index c88fb73f61f135..b1aa402ff537bf 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -107,6 +107,8 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ attn_mask_ptr;
     int mask_head_mod_size;
     int mask_seq_q_mod_size;
+    void * __restrict__ attn_mask_start_row_indices_ptr;
+    int attn_mask_start_row;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
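The header threads three new arguments through both dense-batch entry points, and the params struct gains the pointer plus the start-row hint. A small sketch of how the head-modulus fallback resolves; the standalone function is mine, it just mirrors the ternary chain in `flash_attn_fwd`/`flash_attn_bwd`, and the axis meaning of `dims[1]` is an assumption:

    #include <cstdint>

    // Mirror of the ternary chain in flash_attn_fwd/_bwd: the head modulus
    // comes from the dense mask dims when a dense mask is present, else from
    // the sparse indices dims, else 0 (feature off). dims[1] is assumed to be
    // the head axis of [batch, heads, seqlen_q, seqlen_k] or [batch, heads, seqlen_k].
    int resolve_mask_head_mod_size(const void *attn_mask, const int64_t *mask_dims,
                                   const void *start_row_indices,
                                   const int64_t *start_row_indices_dims) {
        return attn_mask ? static_cast<int>(mask_dims[1])
                         : start_row_indices ? static_cast<int>(start_row_indices_dims[1])
                                             : 0;
    }

Keying the dispatch purely off the pointer (null means disabled) is what lets the varlen paths opt out by passing `nullptr`.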
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index e1ff59778adf55..a0399677296e5d 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -427,11 +427,15 @@ inline __device__ void convert_dKV(const Params &params) {
 
 template
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
+    const bool Is_sparse_attn_mask = params.attn_mask_start_row_indices_ptr != nullptr;
+    const int attn_mask_start_row = params.attn_mask_start_row;
+
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
 
     // Shared memory.
+    __shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
     extern __shared__ char smem_[];
 
     // The thread index.
@@ -475,11 +479,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
         + (m_block_max - 1) * kBlockM;
-    const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
+    const uint64_t row_offset_mask = (uint64_t)((bidb * params.mask_head_mod_size
         + (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size
         + ((m_block_max - 1) * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
         + n_block * kBlockN;
+    const index_t row_offset_sparse_mask = (bidb * params.mask_head_mod_size + bidh % params.mask_head_mod_size) * params.seqlen_k + n_block * kBlockN;
+
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.q_row_stride, _1{}));
@@ -508,6 +514,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
                                Shape<Int<kBlockM>, Int<kBlockN>>{},
                                make_stride(params.seqlen_k, _1{}));
+    Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
+                                     Shape<Int<kBlockN>>{});
 
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQdO{});
@@ -531,6 +539,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
     // sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
+    Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
     Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
                                 Shape<Int<kBlockM>>{});
 
@@ -796,6 +805,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
+    if (Is_sparse_attn_mask) {
+        if (tidx < kBlockN) {
+            sSparseMask(tidx) = gSparseMask(tidx);
+        }
+        __syncthreads();
+    }
+
     auto seed = params.rng_state[0];
     auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
 
@@ -849,7 +865,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
             // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
             // But we still want to mask out elements beyond actual_seqlen_k.
-            if (m_block * kBlockM < (n_block + 1) * kBlockN
+            if (Is_sparse_attn_mask && m_block * kBlockM >= attn_mask_start_row) {
+                flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
+                                                m_block * kBlockM + get<0>(taccScS_row(0)),
+                                                AtomLayoutMS * 16, n_block * kBlockN);
+            } else if (m_block * kBlockM < (n_block + 1) * kBlockN
                 || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k,
                                          m_block * kBlockM + get<0>(taccScS_row(0)),
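One behavioural change rides along in this file: `row_offset_mask` is now computed in `uint64_t` instead of `index_t`. A standalone illustration of the overflow this guards against; the sizes are example values of mine, and the assumption is that `Kernel_traits::index_t` can be 32-bit:

    #include <cstdint>
    #include <cstdio>

    // With a 32-bit index type, the element offset into a
    // [batch, heads, seqlen_q, seqlen_k] mask can exceed 2^32 - 1.
    int main() {
        const int64_t bidb = 7, mask_head_mod_size = 32, bidh = 31;
        const int64_t mask_seq_q_mod_size = 8192, seqlen_k = 8192;
        const int64_t n_block = 0, kBlockN = 128;
        const int64_t row_base =
            (bidb * mask_head_mod_size + bidh % mask_head_mod_size) * mask_seq_q_mod_size;
        const uint64_t offset64 = (uint64_t)row_base * seqlen_k + n_block * kBlockN;
        const uint32_t offset32 = (uint32_t)offset64;  // what a 32-bit index_t would keep
        printf("64-bit offset: %llu, truncated to 32 bits: %u\n",
               (unsigned long long)offset64, offset32);  // ~1.7e10 vs its low 32 bits
        return 0;
    }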
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index 2c62e6c5797df7..442df04418d49b 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const bool is_attn_mask = params.attn_mask_ptr != nullptr;
     const bool is_deterministic = params.num_splits == 1;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
+    params.attn_mask_start_row = (int)(params.attn_mask_start_row / Kernel_traits::kBlockM) * Kernel_traits::kBlockM;
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
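Both launch templates (this one and the forward one further down) floor `params.attn_mask_start_row` to a multiple of the query tile, because the kernels gate the sparse path with the whole-tile test `m_block * kBlockM >= attn_mask_start_row`. A minimal sketch of that arithmetic, with kBlockM = 128 as an example value only:

    #include <cassert>

    constexpr int floor_to_block(int start_row, int kBlockM) {
        return (start_row / kBlockM) * kBlockM;
    }

    int main() {
        constexpr int kBlockM = 128;
        static_assert(floor_to_block(300, kBlockM) == 256, "300 floors to tile 2");
        // Rows 256..299 now also take the sparse path; that is harmless because
        // apply_sparse_mask_causal re-checks each row against the per-column
        // start index before masking anything.
        assert(floor_to_block(kBlockM, kBlockM) == kBlockM);
        return 0;
    }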
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 83dff13ccb5d68..ac58d0af1ff283 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -121,11 +121,15 @@ inline __device__ void write_softmax_to_gmem(
 
 template
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
+    const bool Is_sparse_attn_mask = params.attn_mask_start_row_indices_ptr != nullptr;
+    const int attn_mask_start_row = params.attn_mask_start_row;
+
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
 
     // Shared memory.
+    __shared__ int32_t sparse_mask_smem_[Kernel_traits::kBlockN];
     extern __shared__ char smem_[];
 
     // The thread index.
@@ -171,11 +175,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
         + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
-    const index_t row_offset_mask = ((bidb * params.mask_head_mod_size
+    const uint64_t row_offset_mask = (uint64_t)((bidb * params.mask_head_mod_size
         + (bidh % params.mask_head_mod_size)) * params.mask_seq_q_mod_size
         + (m_block * kBlockM % params.mask_seq_q_mod_size)) * params.seqlen_k
         + (n_block_max - 1) * kBlockN;
+    const index_t row_offset_sparse_mask = (bidb * params.mask_head_mod_size + bidh % params.mask_head_mod_size) * params.seqlen_k + (n_block_max - 1) * kBlockN;
+
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.q_row_stride, _1{}));
@@ -193,6 +199,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                Shape<Int<kBlockM>, Int<kBlockN>>{},
                                make_stride(params.seqlen_k, _1{}));
+    Tensor gSparseMask = make_tensor(make_gmem_ptr(reinterpret_cast<int32_t *>(params.attn_mask_start_row_indices_ptr) + row_offset_sparse_mask),
+                                     Shape<Int<kBlockN>>{});
+
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
     // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
@@ -201,6 +210,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sSparseMask = make_tensor(make_smem_ptr(reinterpret_cast<int32_t *>(sparse_mask_smem_)), Shape<Int<kBlockN>>{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -406,12 +416,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            // Idk why it's get<1> and not get<0> of the stride.
           // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
            // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     // m_block * kBlockM + get<0>(idx_row(0)),
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            if (Is_sparse_attn_mask && m_block * kBlockM >= attn_mask_start_row) {
+                if (tidx < kBlockN) {
+                    sSparseMask(tidx) = gSparseMask(tidx);
+                }
+                __syncthreads();
+                flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN, binfo.actual_seqlen_k,
+                                                // m_block * kBlockM + get<0>(idx_row(0)),
+                                                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                                kNWarps * 16, n_block * kBlockN);
+                                                // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
+                                                // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+                gSparseMask.data() = gSparseMask.data() + (-kBlockN);
+            } else {
+                flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                         // m_block * kBlockM + get<0>(idx_row(0)),
+                                         m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                         kNWarps * 16);
+                                         // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
+                                         // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            }
         }
 
         flash::cp_async_wait<0>();
@@ -500,6 +524,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                     params.unscale_softmax);
             tPgMask.data() = tPgMask.data() + (-kBlockN);
         }
+        if (Is_causal && Is_sparse_attn_mask && m_block * kBlockM >= params.attn_mask_start_row) {
+            if (tidx < kBlockN) {
+                sSparseMask(tidx) = gSparseMask(tidx);
+            }
+            __syncthreads();
+            flash::apply_sparse_mask_causal(scores, sSparseMask, n_block * kBlockN, binfo.actual_seqlen_k,
+                                            // m_block * kBlockM + get<0>(idx_row(0)),
+                                            m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                            kNWarps * 16, n_block * kBlockN);
+                                            // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
+                                            // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            gSparseMask.data() = gSparseMask.data() + (-kBlockN);
+        }
+
         if (Is_equal_seq_qk) {
             softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         } else {
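Both sparse-path hunks above stage one int32 start row per key column into shared memory before masking, and the guard branch is uniform per block (it depends only on `m_block` and params), so the `__syncthreads()` inside it is safe. The same stage-then-sync idiom in isolation; the `stage_start_rows` kernel and its launch are hypothetical, assuming `blockDim.x >= kBlockN`:

    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int kBlockN = 128;

    // Threads 0..kBlockN-1 copy one value each; after the sync, every thread
    // may read any column (here: the reversed one, just to prove the point).
    __global__ void stage_start_rows(const int32_t *g_start_rows, int32_t *g_out) {
        __shared__ int32_t s_start_rows[kBlockN];
        if (threadIdx.x < kBlockN) {
            s_start_rows[threadIdx.x] = g_start_rows[threadIdx.x];
        }
        __syncthreads();
        if (threadIdx.x < kBlockN) {
            g_out[threadIdx.x] = s_start_rows[kBlockN - 1 - threadIdx.x];
        }
    }

    int main() {
        int32_t h_in[kBlockN], h_out[kBlockN];
        for (int i = 0; i < kBlockN; ++i) h_in[i] = i;
        int32_t *d_in, *d_out;
        cudaMalloc(&d_in, sizeof(h_in));
        cudaMalloc(&d_out, sizeof(h_out));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
        stage_start_rows<<<1, 128>>>(d_in, d_out);
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
        printf("out[0] = %d (expect %d)\n", h_out[0], kBlockN - 1);
        cudaFree(d_in); cudaFree(d_out);
        return 0;
    }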
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index ce162ee1b12d49..43334282055185 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool return_softmax = params.p_ptr != nullptr;
     const bool is_attn_mask = params.attn_mask_ptr != nullptr;
     const bool is_equal_qk = (params.cu_seqlens_q == nullptr) && (params.cu_seqlens_k == nullptr) && (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
+    params.attn_mask_start_row = (int)(params.attn_mask_start_row / Kernel_traits::kBlockM) * Kernel_traits::kBlockM;
     BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index 2e3cb52192f13e..29730c3e0c9944 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -119,7 +119,8 @@ struct Flash_fwd_kernel_traits : public Base {
     static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
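A back-of-envelope check of the new shared-memory term: the start-row indices cost one int32 per key column of the tile, 512 bytes at kBlockN = 128 (an example tile size, not a specific trait instantiation). The kernels declare a static `__shared__` array of exactly this size, separate from the dynamic `extern __shared__` pool; growing the traits constants as well presumably keeps the launch-time shared-memory budget checks conservative.

    #include <cstdint>

    constexpr int kBlockN = 128;
    constexpr int kSmemSparseMaskIndicesSize = kBlockN * sizeof(int32_t);
    static_assert(kSmemSparseMaskIndicesSize == 512,
                  "512 bytes per 128-column tile");

    int main() { return 0; }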
     static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
     static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
-    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
+    static constexpr int kSmemSparseMaskIndicesSize = kBlockN * sizeof(int32_t);
+    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize + kSmemSparseMaskIndicesSize;
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
     static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
@@ -312,15 +313,16 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
     static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
+    static constexpr int kSmemSparseMaskIndicesSize = kBlockN * sizeof(int32_t);
     static constexpr int kSmemSize = kSmemQdOSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
            : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
-    static constexpr int kSmemSize1colblock = kSmemQdOSize
+    static constexpr int kSmemSize1colblock = kSmemSparseMaskIndicesSize + kSmemQdOSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + kSmemPSize
            : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
-    static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
+    static constexpr int kSmemSize1rowblock = kSmemSparseMaskIndicesSize + kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
         + kSmemdSSize + kSmemPSize;
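diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 1d2307454545ec..148f2fce5b572a 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -177,6 +177,42 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const u
     }
 }
 
+template <typename Engine, typename Layout, typename EngineStart, typename LayoutStart>
+inline __device__ void apply_sparse_mask_causal(Tensor<Engine, Layout> &tensor, Tensor<EngineStart, LayoutStart> &attn_mask_start_row_indices, const uint32_t col_idx_offset_,
+                                                const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
+                                                const uint32_t warp_row_stride, const uint32_t mask_col_idx_offset) {
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const uint32_t lane_id = threadIdx.x % 32;
+    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
+    const uint32_t row_idx_offset = row_idx_offset_;
+    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    #pragma unroll
+    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const uint32_t col_idx_base = col_idx_offset + nj * 8;
+        #pragma unroll
+        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+            const uint32_t col_idx = col_idx_base + j;
+            const uint32_t start_row = attn_mask_start_row_indices(col_idx - mask_col_idx_offset);
+            #pragma unroll
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+                #pragma unroll
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const uint32_t row_idx = row_idx_base + i * 8;
+                    const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+                    if (col_idx >= col_idx_limit) {
+                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                    }
+                    else if (col_idx < col_idx_limit - 1 && row_idx >= start_row) {
+                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+}
+
 //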
 // TODO(umiswing): support cu_attn_mask
 // This kernel should work after dealing with input cu_seq indicating mask position.
 template
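Read element-wise, `apply_sparse_mask_causal` layers one rule on top of causal masking: score (r, c) is dropped once r has reached `start_row[c]`, except for the last surviving causal column (c == min(seqlen_k, r + 1) - 1), which is never sparse-masked, so every row keeps at least one finite score before softmax. A host-side reference of that predicate as I read the loop nest; `build_mask` is my sketch for one batch/head slice, not library code:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // masked[r * seqlen_k + c] == 1 when (r, c) gets -INFINITY:
    // above the causal diagonal, or past the column's start row
    // (the diagonal element itself is exempt from the sparse rule).
    std::vector<uint8_t> build_mask(const std::vector<uint32_t> &start_row,
                                    uint32_t seqlen_q, uint32_t seqlen_k) {
        std::vector<uint8_t> masked(seqlen_q * seqlen_k, 0);
        for (uint32_t r = 0; r < seqlen_q; ++r) {
            const uint32_t col_limit = std::min(seqlen_k, r + 1);
            for (uint32_t c = 0; c < seqlen_k; ++c) {
                const bool causal = c >= col_limit;
                const bool sparse = c < col_limit - 1 && r >= start_row[c];
                masked[r * seqlen_k + c] = causal || sparse;
            }
        }
        return masked;
    }

    int main() {
        // Column 0 stops attending from row 2 onward; column 1 from row 3.
        auto m = build_mask({2, 3, 4, 4}, 4, 4);
        assert(m[0 * 4 + 1] == 1);  // causal: col 1 > row 0
        assert(m[2 * 4 + 0] == 1);  // sparse: row 2 >= start_row[0] = 2
        assert(m[2 * 4 + 2] == 0);  // diagonal always kept
        return 0;
    }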