diff --git a/src/kernels/flash_attn/flash_api.cpp b/src/kernels/flash_attn/flash_api.cpp
index 59e3f4df..4190ca4e 100644
--- a/src/kernels/flash_attn/flash_api.cpp
+++ b/src/kernels/flash_attn/flash_api.cpp
@@ -276,7 +276,7 @@ mha_varlen_fwd(at::Tensor &q,  // [n_tokens, n_heads, head_dim]
     const int n_blocks = !paged_KV ? 0 : k.size(0);
     const int block_size = !paged_KV ? 1 : k.size(1);
     // TODO: support smaller block sizes
-    TORCH_CHECK(!paged_KV || block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+    TORCH_CHECK(!paged_KV || block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
 
     // [n_tokens, n_heads, head_dim]
     const auto sizes = q.sizes();
diff --git a/src/kernels/flash_attn/src/flash_fwd_kernel.h b/src/kernels/flash_attn/src/flash_fwd_kernel.h
index bb02a301..1b659e01 100644
--- a/src/kernels/flash_attn/src/flash_fwd_kernel.h
+++ b/src/kernels/flash_attn/src/flash_fwd_kernel.h
@@ -515,16 +515,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We move K and V to the last block.
     const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
     const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
-    const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
-    const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
     const index_t row_offset_k = block_table == nullptr
         ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
-        : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+        : (bidh / params.h_h_k_ratio) * params.k_head_stride;  // block addresses are later resolved per-thread
+
     const index_t row_offset_v = block_table == nullptr
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
-        : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+        : (bidh / params.h_h_k_ratio) * params.v_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -544,15 +543,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 
-    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
-    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-
-    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
-    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
-    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
-    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
-    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
+    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
+    auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
+
+    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+
+    Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
+    Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
+
+    Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
+    Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
+    Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
+    Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
+
+    if (block_table != nullptr) {
+        tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+                                      block_table, params.k_batch_stride, params.k_row_stride);
+        tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+                                      block_table, params.v_batch_stride, params.v_row_stride);
+    }
 
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -590,8 +604,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
 
     // Repeat the partitioning with identity layouts
-    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);    // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);      // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+    Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
 
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -608,11 +623,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Prologue
 
     // Copy from Knew to K, optionally apply rotary embedding.
-    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
-    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
-    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
+        typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
@@ -629,10 +645,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                       make_stride(params.rotary_dim / 2, _1{}));
-        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
-        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
-        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
-        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+
+        Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+
+        Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
+        Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
+        Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
+        Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
+
         // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
@@ -653,8 +676,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                                                  + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
                                    Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                    make_stride(params.vnew_row_stride, _1{}));
-        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
-        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+        typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
+        auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
+        Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+        Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+
+        auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
+        auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));
 
         const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
         auto tKgK_data = tKgK.data();
@@ -694,14 +722,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                 if (n_block > n_block_copy_min) {
-                    const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
-                    const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
-                    const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
-                    const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
-                    const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
-                    const int offset_diff = block_table_offset_next - block_table_offset_cur;
-                    tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
-                    tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
+                    tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                                                    block_table, params.v_batch_stride, params.v_row_stride);
+                    tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                                                    block_table, params.k_batch_stride, params.k_row_stride);
                 }
             }
         }
@@ -714,9 +738,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Read Q from gmem to smem, optionally apply rotary embedding.
     if (!Append_KV || params.rotary_dim == 0) {
         // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
+        typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
@@ -751,7 +779,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
 
@@ -790,17 +818,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             if (block_table == nullptr) {
                 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
             } else {
-                const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
-                const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
-                const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
-                const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
-                tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+                tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
+                                                block_table, params.v_batch_stride, params.v_row_stride);
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
             flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+                gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
@@ -825,13 +850,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             if (block_table == nullptr) {
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             } else {
-                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
-                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
-                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
-                const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
-                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+                tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                                                block_table, params.k_batch_stride, params.k_row_stride);
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -868,13 +890,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
-            const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
-            const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
-            const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
-            tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+            tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
+                                            block_table, params.v_batch_stride, params.v_row_stride);
         }
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
         flash::gemm(
@@ -889,13 +908,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
-            const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
-            const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
-            const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
-            tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                                            block_table, params.k_batch_stride, params.k_row_stride);
         }
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
         // This cp_async_fence needs to be in the if block, otherwise the synchronization
         // isn't right and we get race conditions.
         cute::cp_async_fence();
diff --git a/src/kernels/flash_attn/src/kernel_traits.h b/src/kernels/flash_attn/src/kernel_traits.h
index e1ee4b83..c0fd202a 100644
--- a/src/kernels/flash_attn/src/kernel_traits.h
+++ b/src/kernels/flash_attn/src/kernel_traits.h
@@ -127,6 +127,18 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+
+    // from how many rows does each thread have to fetch
+    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
+    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
+    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
+    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
+    // mainloop iteration. Rudimentary testing shows no slowdown.
+    using GmemTiledCopyQKVPaged = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
+
     using GmemTiledCopyO = decltype(
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
@@ -152,6 +164,14 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
+    using GmemTiledCopyRotcossinPaged = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinContPaged = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/kernels/flash_attn/src/utils.h b/src/kernels/flash_attn/src/utils.h
index 4bcfa7f6..c5d67e74 100644
--- a/src/kernels/flash_attn/src/utils.h
+++ b/src/kernels/flash_attn/src/utils.h
@@ -379,4 +379,79 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// resolves initial base offset of a slice of a paged kv copy from gmem.
+// assumes that the tensor has already been positioned at the correct head.
+template <typename Kernel_traits>
+__forceinline__ __device__
+int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
+                                     const int* block_table, const int page_stride, const int row_stride) {
+    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
+    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
+    constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+
+    const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
+    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
+    const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
+    const int page_offset = global_row_offset % page_block_size;
+    const int virtual_page_idx = global_row_offset / page_block_size;
+
+    return block_table[virtual_page_idx] * page_stride
+        + page_offset * row_stride
+        + col_offset;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// advances base address of a slice of a paged copy from gmem
+template <typename Kernel_traits>
+__forceinline__ __device__
+int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
+                                        const int* block_table, const int page_stride, const int row_stride) {
+    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
+    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+
+    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
+
+    const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
+    const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
+
+    const int page_offset_cur = global_row_offset_cur % page_block_size;
+    const int page_offset_next = global_row_offset_next % page_block_size;
+
+    const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
+    const int virtual_page_idx_next = global_row_offset_next / page_block_size;
+
+    const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
+    const int offset_diff = page_offset_next - page_offset_cur;
+
+    return table_diff * page_stride + offset_diff * row_stride;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
+// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
+// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
+template <typename Layout>
+__forceinline__ __device__
+auto reshape_thread_tile(Layout l) {
+    return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
+                       append(get<0>(l.stride()), get<2>(l.stride())));
+}
+
+// reshapes and flattens the thread tile layout. A separate function is needed for the case where
+// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
+// for the case of swizzled layouts
+template <typename Layout>
+__forceinline__ __device__
+auto reshape_flatten_thread_tile(Layout l) {
+    auto mode_0 = filter(flatten(get<0>(l)));
+    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
+                       append(mode_0.stride(), get<2>(l.stride())));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash
diff --git a/src/layers/attention_test.cpp b/src/layers/attention_test.cpp
index c6c76df9..62d6502b 100644
--- a/src/layers/attention_test.cpp
+++ b/src/layers/attention_test.cpp
@@ -322,7 +322,7 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::Values(torch::kCUDA),
         ::testing::Values(torch::kHalf, torch::kBFloat16),
         ::testing::Values(1, 10),          // batch_size
-        ::testing::Values(256),            // block_size
+        ::testing::Values(16, 80, 256),    // block_size
        ::testing::Values(1, 10),          // q_max_seq_len
         ::testing::Values(100, 1000),      // k_max_seq_len
         ::testing::Values(6),              // n_heads
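
Reviewer note (not part of the patch): the standalone sketch below replays, on the host, the per-thread paged-KV offset arithmetic introduced in utils.h, as a quick consistency check that the advance delta equals the difference between two initial offsets. The kernel-traits constants are assumed example values (kBlockN = 128, 128 threads, 8 threads per row, 8 elements per load), the strides and block-table contents are made up, and init_offset / advance_offset are illustrative stand-ins for the device helpers, not the actual functions.

// Host-side sketch (assumed constants, illustrative names); compiles as plain C++.
#include <cassert>
#include <vector>

constexpr int kBlockN = 128;
constexpr int kNThreads = 128;
constexpr int kGmemThreadsPerRow = 8;
constexpr int kGmemElemsPerLoad = 8;
constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);  // 8 contiguous rows per thread

// Mirrors init_thread_kv_page_slice_offset: absolute element offset of this thread's
// first element for K/V block index (n_block_max - 1), resolved through the block table.
int init_offset(int tidx, int n_block_max, int page_block_size,
                const std::vector<int>& block_table, int page_stride, int row_stride) {
  const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
  const int global_row = tidx / kGmemThreadsPerRow * kGmemRowsPerThread + (n_block_max - 1) * kBlockN;
  return block_table[global_row / page_block_size] * page_stride
       + (global_row % page_block_size) * row_stride + col_offset;
}

// Mirrors advance_thread_kv_page_slice_offset: delta when stepping from block n_block
// to block n_block - 1 (the column offset cancels out of the difference).
int advance_offset(int tidx, int n_block, int page_block_size,
                   const std::vector<int>& block_table, int page_stride, int row_stride) {
  const int row = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
  const int row_cur = row + n_block * kBlockN;
  const int row_next = row + (n_block - 1) * kBlockN;
  return (block_table[row_next / page_block_size] - block_table[row_cur / page_block_size]) * page_stride
       + (row_next % page_block_size - row_cur % page_block_size) * row_stride;
}

int main() {
  const int page_block_size = 16;            // the new minimum enforced in flash_api.cpp
  const int row_stride = 8 * 128;            // made-up: n_kv_heads * head_dim
  const int page_stride = page_block_size * row_stride;
  std::vector<int> block_table(24);          // enough pages for 3 blocks of 128 rows
  for (int i = 0; i < 24; ++i) block_table[i] = 37 + 2 * i;  // arbitrary physical page ids

  // The delta returned by advance_offset must equal the difference of absolute offsets.
  for (int tidx = 0; tidx < kNThreads; ++tidx) {
    for (int n_block = 2; n_block >= 1; --n_block) {
      const int cur  = init_offset(tidx, n_block + 1, page_block_size, block_table, page_stride, row_stride);
      const int next = init_offset(tidx, n_block,     page_block_size, block_table, page_stride, row_stride);
      assert(next - cur == advance_offset(tidx, n_block, page_block_size, block_table, page_stride, row_stride));
    }
  }
  return 0;
}

Because each thread owns a contiguous kGmemRowsPerThread-row tile and the block size is now required to be a multiple of 16, a thread's rows never straddle a page, so a single block-table lookup per mainloop iteration suffices; that is the property the kernel_traits.h comment relies on.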