From 28cffd366888649a1e9d871efec32e67b88070cb Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 29 Jul 2024 05:05:32 -0700
Subject: [PATCH] feat: sliding window attention (#406)

As requested in #390, this PR implements sliding window attention. This PR also results in a slight performance degradation because we didn't specialize the kernels that use / do not use sliding window. I believe we can address this by landing the JIT compilation feature. I'll merge this feature first and improve performance in later PRs.

---
 include/flashinfer/attention/decode.cuh | 45 ++-
 include/flashinfer/attention/handler.cuh | 4 +-
 include/flashinfer/attention/prefill.cuh | 83 ++--
 include/flashinfer/decode_attention_decl.cuh | 21 +-
 include/flashinfer/prefill_attention_decl.cuh | 25 +-
 include/flashinfer/utils.cuh | 7 +
 python/csrc/batch_decode.cu | 8 +-
 python/csrc/batch_prefill.cu | 27 +-
 python/csrc/flashinfer_ops.h | 33 +-
 python/csrc/single_decode.cu | 9 +-
 python/csrc/single_prefill.cu | 15 +-
 python/flashinfer/decode.py | 18 +
 python/flashinfer/prefill.py | 36 ++
 python/generate_batch_paged_decode_inst.py | 2 +-
 python/generate_batch_paged_prefill_inst.py | 4 +-
 python/generate_batch_ragged_prefill_inst.py | 3 +-
 python/generate_single_decode_inst.py | 2 +-
 python/generate_single_prefill_inst.py | 2 +-
 python/tests/test_sliding_window.py | 355 ++++++++++++++++++
 src/flashinfer_ops.cuh | 10 +-
 20 files changed, 586 insertions(+), 123 deletions(-)
 create mode 100644 python/tests/test_sliding_window.py

diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index 907a686e..6af2a46d 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -74,9 +74,10 @@ template & q_vec, const vec_t& freq, uint32_t kv_idx_base, - uint32_t iter_base, uint32_t iter_bound, - const int32_t q_offset, float alibi_slope, float* s, - state_t& st, const float logits_soft_cap) { + uint32_t iter_base, uint32_t left_close_bound, + uint32_t iter_bound, const int32_t q_offset, + float alibi_slope, float* s, state_t& st, + const float logits_soft_cap) { uint32_t tx = threadIdx.x, tz = threadIdx.z; float m_prev = st.m; #pragma unroll
@@ -100,9 +101,10 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage s[j] += math::shfl_xor_sync(s[j], offset); } s[j] = apply_logits_post_hook(s[j], logits_soft_cap); - s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; + const uint32_t pos = kv_idx_base + tz * tile_size + j; + s[j] = (iter_base + tz * tile_size + j < iter_bound && pos >= left_close_bound) ? s[j] : -5e4; if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { - s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset); + s[j] += alibi_slope * float(int(pos) - q_offset); } st.m = max(st.m, s[j]); } @@ -212,9 +214,9 @@ template = 0) ?
sub_if_greater_or_zero(seq_len, window_left + 1) : 0; extern __shared__ uint8_t smem[]; DTypeKV* k_smem = (DTypeKV*)smem; @@ -303,8 +307,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ block.sync(); compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, - freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, - seq_len - 1, alibi_slope, s, st_local, logits_soft_cap); + freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, left_close_bound, + kv_chunk_size, seq_len - 1, alibi_slope, s, st_local, logits_soft_cap); block.sync(); // load k for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -389,8 +393,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap, - float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { + float* __restrict__ lse, bool* __restrict__ block_valid_mask, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { auto block = cg::this_thread_block(); sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); @@ -415,6 +419,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( : 0; const uint32_t seq_len = partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len; + const uint32_t left_close_bound = + (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0; const uint32_t mapped_batch_idx = partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx; @@ -521,8 +527,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( freq, (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz, - iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st, - logits_soft_cap); + iter * tile_size_per_bdx * bdy * bdz, left_close_bound, kv_chunk_len, q_offset_val, + alibi_slope, s, st, logits_soft_cap); block.sync(); #pragma unroll @@ -627,8 +633,9 @@ template paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; @@ -761,6 +770,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( (void*)&o, (void*)&lse, (void*)&block_valid_mask, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&rope_rcp_scale, @@ -782,6 +792,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( (void*)&tmp_v, (void*)&tmp_s, (void*)&block_valid_mask, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&rope_rcp_scale, diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index af44180c..796a5f8f 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -42,8 +42,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, - float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap, - float sm_scale, float rope_rcp_scale, float rope_rcp_theta); + float* __restrict__ lse, bool* __restrict__ block_valid_mask, int maybe_window_left, + float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta); /*! * \brief Compute the maximum number of pages per batch and the new batch size diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index a4f6cbc6..44adcd0c 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -56,13 +56,6 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256)); } -/*! - * \brief Return x - y if x > y, otherwise return 0. - */ -__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { - return (x > y) ? x - y : 0U; -} - template __device__ __forceinline__ uint32_t get_warp_idx_x() { if constexpr (num_warps_x == 1) { @@ -588,8 +581,9 @@ template __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, - const uint32_t kv_len, const uint32_t chunk_end, - const uint_fastdiv group_size, uint8_t* custom_mask, + const uint32_t kv_len, const uint32_t window_left, + const uint32_t chunk_end, const uint_fastdiv group_size, + uint8_t* custom_mask, DTypeQKAccum (*s_frag)[num_frags_z][8]) { const uint32_t lane_idx = threadIdx.x; #pragma unroll @@ -605,8 +599,9 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, reg_id % 2; const bool out_of_boundary = (mask_mode == MaskMode::kCausal - ? 
(kv_idx + qo_len > kv_len + q_idx || (partition_kv && kv_idx >= chunk_end)) - : kv_idx >= chunk_end); + ? (kv_idx + qo_len > kv_len + q_idx || (partition_kv && kv_idx >= chunk_end) || + kv_idx + qo_len + window_left < kv_len + q_idx) + : kv_idx >= chunk_end || kv_idx + qo_len + window_left < kv_len + q_idx); s_frag[fx][fz][reg_id] = (out_of_boundary || (mask_mode == MaskMode::kCustom && q_idx < qo_len && @@ -972,8 +967,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n, - const uint32_t kv_stride_h, const float logits_soft_cap, float sm_scale, - const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + const uint32_t kv_stride_h, const int32_t maybe_window_left, const float logits_soft_cap, + float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= @@ -991,6 +986,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; const uint32_t chunk_end = partition_kv ? min((chunk_idx + 1) * chunk_size, kv_len) : kv_len; auto block = cg::this_thread_block(); + const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); @@ -1067,6 +1063,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC : chunk_end - chunk_start, 16 * num_warps_z * num_frags_z); + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta, + qo_len + window_left + chunk_start), + (16 * num_warps_z * num_frags_z)); + const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal ? 
min(chunk_end - chunk_start, @@ -1126,14 +1127,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, custom_mask, s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, custom_mask, s_frag); } else { - if (iter >= mask_iteration) { + if (iter >= mask_iteration || iter < window_iteration) { mask_s( qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, nullptr, s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); } } @@ -1214,8 +1215,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg float* __restrict__ lse, bool* __restrict__ block_valid_mask, IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n, - const uint32_t kv_stride_h, const float logits_soft_cap, float sm_scale, - const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { + const uint32_t kv_stride_h, const int32_t maybe_window_left, const float logits_soft_cap, + float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= @@ -1235,6 +1236,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; + const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0; const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len; @@ -1325,6 +1327,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg : chunk_end - chunk_start), 16 * num_warps_z * num_frags_z); + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta, + qo_len + window_left + chunk_start), + (16 * num_warps_z * num_frags_z)); + const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal ? 
min(chunk_end - chunk_start, @@ -1392,14 +1399,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, custom_mask + qk_indptr[request_idx], s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, custom_mask + qk_indptr[request_idx], + s_frag); } else { - if (iter >= mask_iteration) { + if (iter >= mask_iteration || iter < window_iteration) { mask_s( qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, nullptr, s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); } } @@ -1483,8 +1491,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage IdType* __restrict__ q_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o, float* __restrict__ lse, bool* __restrict__ block_valid_mask, IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size, - const float logits_soft_cap, float sm_scale, const float log2_rope_rcp_scale, - const float log2_rope_rcp_theta) { + int32_t maybe_window_left, const float logits_soft_cap, float sm_scale, + const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= @@ -1509,6 +1517,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage 1) * paged_kv.page_size + paged_kv.last_page_len[request_idx] : 0; + const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0; const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len; @@ -1632,6 +1641,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage : chunk_end - chunk_start), 16 * num_warps_z * num_frags_z); + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta, + qo_len + window_left + chunk_start), + (16 * num_warps_z * num_frags_z)); + const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal ? 
min(chunk_end - chunk_start, @@ -1684,14 +1698,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, custom_mask + qk_indptr[request_idx], s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, custom_mask + qk_indptr[request_idx], + s_frag); } else { - if (iter >= mask_iteration) { + if (iter >= mask_iteration || iter < window_iteration) { mask_s( qo_packed_idx_base, chunk_start + (iter * num_warps_z + get_warp_idx_z()) * num_frags_z * 16, - qo_len, kv_len, chunk_end, group_size, nullptr, s_frag); + qo_len, kv_len, window_left, chunk_end, group_size, nullptr, s_frag); } } @@ -1767,7 +1782,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); @@ -1877,6 +1892,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( (void*)&q_stride_h, (void*)&kv_stride_n, (void*)&kv_stride_h, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, @@ -1903,6 +1919,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( (void*)&q_stride_h, (void*)&kv_stride_n, (void*)&kv_stride_h, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, @@ -1930,8 +1947,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, - uint32_t kv_stride_n, uint32_t kv_stride_h, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { + uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_frags_x = get_num_frags_x(); @@ -2015,6 +2032,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&q_stride_h, (void*)&kv_stride_n, (void*)&kv_stride_h, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, @@ -2051,6 +2069,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&q_stride_h, (void*)&kv_stride_n, (void*)&kv_stride_h, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, @@ -2075,8 +2094,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( uint8_t* custom_mask, IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, - uint32_t padded_batch_size, float logits_soft_cap, float sm_scale, float 
rope_scale, - float rope_theta, cudaStream_t stream) { + uint32_t padded_batch_size, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_frags_x = get_num_frags_x(); @@ -2157,6 +2176,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&block_valid_mask, (void*)&kv_chunk_size_ptr, (void*)&group_size_fastdiv, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, @@ -2186,6 +2206,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&block_valid_mask, (void*)&kv_chunk_size_ptr, (void*)&group_size_fastdiv, + (void*)&window_left, (void*)&logits_soft_cap, (void*)&sm_scale, (void*)&log2_rope_rcp_scale, diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 250ea201..aa9bb2a4 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -29,12 +29,10 @@ namespace flashinfer { template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t seq_len, - QKVLayout kv_layout, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream); +cudaError_t SingleDecodeWithKVCacheDispatched( + DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t seq_len, QKVLayout kv_layout, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream); template paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream) { + uint32_t num_qo_heads, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); @@ -81,8 +80,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, - handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), num_qo_heads, logits_soft_cap, - sm_scale, rope_scale, rope_theta, stream); + handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), num_qo_heads, window_left, + logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); } } // namespace flashinfer diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 4d699caf..93b40dac 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -35,7 +35,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched( DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp, float* lse, 
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream); template paged_kv, uint8_t* custom_mask, IdType* qk_indptr, - DTypeOut* o, float* lse, uint32_t num_qo_heads, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, cudaStream_t stream) { + DTypeOut* o, float* lse, uint32_t num_qo_heads, int32_t window_left, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { DTypeOut* tmp_v = nullptr; float* tmp_s = nullptr; IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, @@ -105,8 +106,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, - kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, logits_soft_cap, - sm_scale, rope_scale, rope_theta, stream); + kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, window_left, + logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } @@ -119,8 +120,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, - uint32_t kv_stride_h, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - cudaStream_t stream) { + uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, cudaStream_t stream) { DTypeOut* tmp_v = nullptr; float* tmp_s = nullptr; IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr, @@ -157,7 +158,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, - logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); + window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, stream); }); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index f12923a2..c604fc2e 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -246,6 +246,13 @@ inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = std::cout << std::endl; } +/*! + * \brief Return x - y if x > y, otherwise return 0. + */ +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { + return (x > y) ? 
x - y : 0U; +} + } // namespace flashinfer #endif // FLASHINFER_UTILS_CUH_ diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 5b936246..130f4abb 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -108,8 +108,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, std::optional paged_kv_cache, std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); bool paged_kv_defined = paged_kv_cache.has_value(); if (paged_kv_defined) { @@ -216,7 +216,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); @@ -249,7 +249,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index b3229156..a060940a 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -62,8 +62,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { bool paged_kv_defined = paged_kv_cache.has_value(); CHECK_INPUT(q); CHECK_INPUT(qo_indptr); @@ -179,7 +179,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error code ", @@ -204,8 +205,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { bool paged_kv_defined = paged_kv_cache.has_value(); CHECK_INPUT(q); CHECK_INPUT(qo_indptr); @@ -331,7 +332,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu static_cast(qk_indptr.data_ptr()), static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error code ", @@ -389,8 +391,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(qo_indptr); CHECK_CUDA(q); CHECK_CUDA(k); @@ -465,7 +467,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, - kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta, + kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, + rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", @@ -488,8 +491,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(qo_indptr); CHECK_CUDA(q); CHECK_CUDA(k); @@ -572,7 +575,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, - kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta, + kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 0f526c26..73fa511e 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -23,19 +23,21 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta); + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta); std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + int window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); std::vector single_prefill_with_kv_cache_custom_mask( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); + bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, std::optional paged_kv_cache, @@ -113,9 +115,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, - bool return_lse); + unsigned int pos_encoding_mode, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta, bool return_lse); BatchDecodeWithPagedKVCachePyTorchWrapper( std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) : handler_(handler_ptr), kv_layout_(kv_layout) {} @@ -146,15 +148,16 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); + int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); std::vector ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + int 
window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), handler_(std::make_shared(enable_cuda_graph)) {} @@ -175,13 +178,13 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); + int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); std::vector ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta, bool return_lse); + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), handler_(std::make_shared(enable_cuda_graph)) {} diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index acb05cd4..10013f9c 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -22,8 +22,9 @@ using namespace flashinfer; torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int pos_encoding_mode, - unsigned int layout, float logits_soft_cap, - float sm_scale, float rope_scale, float rope_theta) { + unsigned int layout, int window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -71,7 +72,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, - kv_layout, logits_soft_cap, sm_scale, rope_scale, rope_theta, + kv_layout, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + @@ -93,7 +94,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, - kv_layout, logits_soft_cap, sm_scale, rope_scale, rope_theta, + kv_layout, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 9dfc740a..2e96f720 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -23,7 +23,8 @@ using namespace flashinfer; std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, unsigned int layout, unsigned 
int pos_encoding_mode, bool allow_fp16_qk_reduction, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_CUDA(q); CHECK_CUDA(k); CHECK_CUDA(v); @@ -91,8 +92,8 @@ std::vector single_prefill_with_kv_cache( static_cast(tmp.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale, - rope_theta, torch_current_stream); + kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, + rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); @@ -114,8 +115,8 @@ std::vector single_prefill_with_kv_cache( std::vector single_prefill_with_kv_cache_custom_mask( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor packed_custom_mask, torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode, - bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + bool allow_fp16_qk_reduction, int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { CHECK_CUDA(q); CHECK_CUDA(k); CHECK_CUDA(v); @@ -184,8 +185,8 @@ std::vector single_prefill_with_kv_cache_custom_mask( static_cast(tmp.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale, - rope_theta, torch_current_stream); + kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, + rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 2b4d4b36..504d82e4 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -68,6 +68,7 @@ def single_decode_with_kv_cache( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -102,6 +103,9 @@ def single_decode_with_kv_cache( The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. 
If greater than 0, the logits will be capped according to @@ -177,6 +181,7 @@ def single_decode_with_kv_cache( TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, False, # allow_fp16_qk_reduction + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -191,6 +196,7 @@ def single_decode_with_kv_cache( tmp, PosEncodingMode[pos_encoding_mode].value, TensorLayout[kv_layout].value, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -546,6 +552,7 @@ def forward( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -581,6 +588,9 @@ def forward( The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -626,6 +636,7 @@ def forward( False, # causal PosEncodingMode[pos_encoding_mode].value, False, # allow_fp16_qk_reduction + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -640,6 +651,7 @@ def forward( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -658,6 +670,7 @@ def forward_return_lse( q_scale: Optional[float] = None, k_scale: Optional[float] = None, v_scale: Optional[float] = None, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -694,6 +707,9 @@ def forward_return_lse( The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. v_scale : Optional[float] The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. 
If greater than 0, the logits will be capped according to @@ -745,6 +761,7 @@ def forward_return_lse( False, # causal PosEncodingMode[pos_encoding_mode].value, False, # allow_fp16_qk_reduction + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -759,6 +776,7 @@ def forward_return_lse( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, + window_left, logits_soft_cap, sm_scale, rope_scale, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 5e0e0aa8..d914c90d 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -66,6 +66,7 @@ def single_prefill_with_kv_cache( pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, + window_left: int = -1, logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -109,6 +110,9 @@ def single_prefill_with_kv_cache( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -192,6 +196,7 @@ def single_prefill_with_kv_cache( TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -208,6 +213,7 @@ def single_prefill_with_kv_cache( TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -226,6 +232,7 @@ def single_prefill_with_kv_cache_return_lse( kv_layout: str = "NHD", pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -270,6 +277,9 @@ def single_prefill_with_kv_cache_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. 
If greater than 0, the logits will be capped according to @@ -371,6 +381,7 @@ def single_prefill_with_kv_cache_return_lse( TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -387,6 +398,7 @@ def single_prefill_with_kv_cache_return_lse( TensorLayout[kv_layout].value, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -806,6 +818,7 @@ def forward( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -842,6 +855,9 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -883,6 +899,7 @@ def forward( causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -901,6 +918,7 @@ def forward( self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -915,6 +933,7 @@ def forward_return_lse( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -949,6 +968,9 @@ def forward_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -993,6 +1015,7 @@ def forward_return_lse( causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -1011,6 +1034,7 @@ def forward_return_lse( self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -1356,6 +1380,7 @@ def forward( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -1382,6 +1407,9 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. 
logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -1429,6 +1457,7 @@ def forward( causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -1446,6 +1475,7 @@ def forward( self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -1461,6 +1491,7 @@ def forward_return_lse( causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + window_left: int = -1, logits_soft_cap: Optional[float] = None, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -1487,6 +1518,9 @@ def forward_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + window_left : int + The left (inclusive) window size for the attention window, when set to ``-1``, the window + size will be set to the full length of the sequence. Defaults to ``-1``. logits_soft_cap : Optional[float] The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not provided, will be set to ``0``. If greater than 0, the logits will be capped according to @@ -1536,6 +1570,7 @@ def forward_return_lse( causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, @@ -1553,6 +1588,7 @@ def forward_return_lse( self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, + window_left, logits_soft_cap, sm_scale, rope_scale, diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index 296a9fbc..6b98adcd 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -46,7 +46,7 @@ def get_cu_file_str( kv_partition_info_t<{idtype}> kv_partition_info, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, - float logits_soft_cap, float sm_scale, float rope_scale, + int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); }} diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index c844b198..a739f777 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -47,8 +47,8 @@ def get_cu_file_str( paged_kv_t paged_kv, uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t max_num_rows, - uint32_t num_qo_heads, uint32_t padded_batch_size, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); + uint32_t num_qo_heads, uint32_t padded_batch_size, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); """.format( logits_hook=logits_hook_literal[int(logits_hook)], warp_layout=warp_layout_literal[warp_layout], diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py index cbb3f2ac..794d166a 100644 --- a/python/generate_batch_ragged_prefill_inst.py +++ 
b/python/generate_batch_ragged_prefill_inst.py @@ -47,7 +47,8 @@ def get_cu_file_str( {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, - uint32_t kv_stride_n, uint32_t kv_stride_h, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, + uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); """.format( warp_layout=warp_layout_literal[warp_layout], diff --git a/python/generate_single_decode_inst.py b/python/generate_single_decode_inst.py index 52af3f00..fc57f36c 100644 --- a/python/generate_single_decode_inst.py +++ b/python/generate_single_decode_inst.py @@ -39,7 +39,7 @@ def get_cu_file_str( template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, - QKVLayout kv_layout, float logits_soft_cap, float sm_scale, float rope_scale, + QKVLayout kv_layout, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); }} diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py index cf1702ab..9b8103ce 100644 --- a/python/generate_single_prefill_inst.py +++ b/python/generate_single_prefill_inst.py @@ -42,7 +42,7 @@ def get_cu_file_str( template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>( {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o, {dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, - uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, + uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); }} diff --git a/python/tests/test_sliding_window.py b/python/tests/test_sliding_window.py new file mode 100644 index 00000000..3fc1b1bb --- /dev/null +++ b/python/tests/test_sliding_window.py @@ -0,0 +1,355 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import numpy +import pytest +import torch +import flashinfer + + +@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999]) +@pytest.mark.parametrize("window_left", [3, 13, 23, 43]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_single_decode_sliding_window( + seq_len, window_left, num_kv_heads, num_qo_heads, head_dim +): + q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0") + k = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + + k_sliced = k[-(window_left + 1) :] + v_sliced = v[-(window_left + 1) :] + + o_ref = flashinfer.single_decode_with_kv_cache(q, k_sliced, v_sliced) + o = flashinfer.single_decode_with_kv_cache(q, k, v, window_left=window_left) + + numpy.testing.assert_allclose(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 3, 13, 32]) +@pytest.mark.parametrize("kv_len", [1, 3, 99, 199, 1999]) +@pytest.mark.parametrize("window_left", [33, 533]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("page_size", [1, 16]) +def test_batch_decode_sliding_window( + batch_size, kv_len, window_left, num_kv_heads, num_qo_heads, head_dim, page_size +): + q = torch.randn( + batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + k_data = torch.randn( + total_num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + v_data = torch.randn( + total_num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + ) + o = wrapper.forward( + q, + (k_data, v_data), + window_left=window_left, + ) + + for i in range(batch_size): + qi = q[i] + ki = torch.cat( + [ + k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( + -1, num_kv_heads, head_dim + ), + k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], + ], + dim=0, + ) + vi = torch.cat( + [ + v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( + -1, num_kv_heads, head_dim + ), + v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], + ], + dim=0, + ) + o_ref_i = flashinfer.single_decode_with_kv_cache( + qi, + ki, + vi, + window_left=window_left, + ) + o_i_np = o[i].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999]) +@pytest.mark.parametrize("window_left", [3, 13, 23, 43]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", 
[64, 128, 256]) +def test_single_decode_prefill_sliding_window_match( + seq_len, window_left, num_kv_heads, num_qo_heads, head_dim +): + q = torch.randn(1, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0") + k = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + o = flashinfer.single_prefill_with_kv_cache( + q, k, v, window_left=window_left, causal=True + ) + o_decoded = flashinfer.single_decode_with_kv_cache( + q[0], k, v, window_left=window_left + ) + numpy.testing.assert_allclose(o.cpu()[0], o_decoded.cpu(), rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("seq_len", [99, 199, 1999]) +@pytest.mark.parametrize("window_left", [43, 233]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_single_prefill_sliding_window( + seq_len, window_left, num_kv_heads, num_qo_heads, head_dim +): + q = torch.randn( + seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + k = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + v = torch.randn( + seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" + ) + + row_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[:, None] + col_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[None, :] + mask = (row_idx >= col_idx) & (row_idx - window_left <= col_idx) + + o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) + o = flashinfer.single_prefill_with_kv_cache( + q, k, v, window_left=window_left, causal=True + ) + numpy.testing.assert_allclose(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 397]) +@pytest.mark.parametrize("qo_len", [37, 47]) +@pytest.mark.parametrize("window_left", [13, 33]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("page_size", [1, 16]) +def test_batch_paged_prefill_sliding_window( + batch_size, + kv_len, + qo_len, + window_left, + num_kv_heads, + num_qo_heads, + head_dim, + page_size, +): + q = torch.randn( + batch_size * qo_len, + num_qo_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + k_data = torch.randn( + total_num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + v_data = torch.randn( + total_num_pages, + page_size, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + wrapper.begin_forward( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + o = wrapper.forward( + q, + (k_data, v_data), + 
window_left=window_left, + ) + + for i in range(batch_size): + qi = q[q_indptr[i] : q_indptr[i + 1]] + ki = torch.cat( + [ + k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( + -1, num_kv_heads, head_dim + ), + k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], + ], + dim=0, + ) + vi = torch.cat( + [ + v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape( + -1, num_kv_heads, head_dim + ), + v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :], + ], + dim=0, + ) + o_ref_i = flashinfer.single_prefill_with_kv_cache( + qi, + ki, + vi, + window_left=window_left, + causal=True, + ) + o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 397]) +@pytest.mark.parametrize("qo_len", [37, 47]) +@pytest.mark.parametrize("window_left", [13, 33]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_batch_ragged_prefill_sliding_window( + batch_size, kv_len, qo_len, window_left, num_kv_heads, num_qo_heads, head_dim +): + q = torch.randn( + batch_size * qo_len, + num_qo_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + k = torch.randn( + batch_size * kv_len, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + v = torch.randn( + batch_size * kv_len, + num_kv_heads, + head_dim, + dtype=torch.float16, + device="cuda:0", + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD") + wrapper.begin_forward( + q_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + o = wrapper.forward( + q, + k, + v, + window_left=window_left, + ) + + for i in range(batch_size): + qi = q[q_indptr[i] : q_indptr[i + 1]] + ki = k[kv_indptr[i] : kv_indptr[i + 1]] + vi = v[kv_indptr[i] : kv_indptr[i + 1]] + o_ref_i = flashinfer.single_prefill_with_kv_cache( + qi, + ki, + vi, + window_left=window_left, + causal=True, + ) + o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_single_prefill_sliding_window(13, 20, 1, 4, 128) + test_batch_paged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128, 1) + test_batch_ragged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128) diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 2c6bbc3b..61b9e746 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -43,6 +43,7 @@ cudaError_t SinglePrefillWithKVCacheCustomMask( MaskMode::kCustom>( q, k, v, custom_mask, o, tmp, lse, num_qo_heads, num_kv_heads, qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, + /*window_left=*/-1, /*logits_soft_cap*/ 0.f, sm_scale, rope_scale, rope_theta, stream); })})}); return cudaSuccess; @@ -98,6 +99,7 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu ALLOW_FP16_QK_REDUCTION, MASK_MODE>( q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads, num_kv_heads, qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, 
rope_scale, rope_theta, stream); })})})}); return cudaSuccess; @@ -129,6 +131,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr, /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, num_qo_heads, num_kv_heads, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); })})})}); return cudaSuccess; @@ -158,8 +161,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>( handler, q, qo_indptr, q_offset, paged_kv, /*custom_mask=*/nullptr, - /*qk_indptr=*/nullptr, o, lse, num_qo_heads, /*logits_soft_cap=*/0.f, sm_scale, - rope_scale, rope_theta, stream); + /*qk_indptr=*/nullptr, o, lse, num_qo_heads, /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); })})})}); return cudaSuccess; } @@ -184,6 +187,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { SingleDecodeWithKVCacheDispatched( q, k, v, o, tmp, num_qo_heads, num_kv_heads, seq_len, kv_layout, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); })}); return cudaSuccess; @@ -215,6 +219,7 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( IdType>( q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr, lse, /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, num_qo_heads, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); })}); @@ -265,6 +270,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( PAGE_STORAGE, HEAD_DIM, LogitsPostHook::kNone, POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( handler, q, q_offset, paged_kv, o, lse, num_qo_heads, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); })}); return cudaSuccess;
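
Usage sketch for the new `window_left` argument. It mirrors the calls exercised in python/tests/test_sliding_window.py (NHD layout, fp16 tensors); the concrete head counts, sequence length, and window size below are illustrative assumptions, not part of the patch:

    import torch
    import flashinfer

    num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
    seq_len, window_left = 1024, 127  # attend to the trailing window_left + 1 positions

    q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")

    # Decode: only the last window_left + 1 KV positions contribute;
    # window_left=-1 (the default) keeps the previous full-attention behavior.
    o_decode = flashinfer.single_decode_with_kv_cache(q, k, v, window_left=window_left)

    # Prefill/append: the same keyword combines with the causal mask.
    q_prefill = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    o_prefill = flashinfer.single_prefill_with_kv_cache(
        q_prefill, k, v, causal=True, window_left=window_left
    )

The batch wrappers accept the same keyword on their forward/forward_return_lse calls, as the paged and ragged prefill tests above show.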
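
For reference, a minimal PyTorch construction of the mask that test_single_prefill_sliding_window compares against: with window size w, query i attends to keys j satisfying i - w <= j <= i, and w = -1 disables the window. The helper name is hypothetical and exists only to make the semantics explicit:

    import torch

    def sliding_window_causal_mask(seq_len: int, window_left: int) -> torch.Tensor:
        # mask[i, j] is True iff query i may attend to key j: causal (j <= i) and,
        # when window_left >= 0, within the window (i - window_left <= j).
        row = torch.arange(seq_len, dtype=torch.int32)[:, None]
        col = torch.arange(seq_len, dtype=torch.int32)[None, :]
        mask = row >= col
        if window_left >= 0:
            mask = mask & ((row - window_left) <= col)
        return mask

    # Passing this as custom_mask to single_prefill_with_kv_cache should match
    # window_left=w with causal=True, which is exactly what the test asserts.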