From 936d500126aa4a26e0d9b46123d1420d3b5e3234 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Wed, 7 May 2025 14:57:58 -0400 Subject: [PATCH 01/12] Sliding window changes, with correctness/performance bug --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 39 ++++--- src/runtime/relax_vm/attn_utils.h | 3 +- src/runtime/relax_vm/paged_kv_cache.cc | 112 ++++++++++++++++--- 3 files changed, 123 insertions(+), 31 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index 1d06bf2f3595..314c1b8a21f6 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -20,7 +20,7 @@ # pylint: disable=too-many-statements,too-many-lines,too-many-arguments,invalid-name import enum import math -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import tvm from tvm import relax as rx @@ -86,6 +86,7 @@ class AttnKind(enum.IntEnum): MHA = 0 MLA = 1 + MHA_SLIDING = 3 class RopeMode(enum.IntEnum): @@ -301,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -377,8 +378,8 @@ def __init__( # pylint: disable=too-many-locals dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - qk_head_dim=qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim, - v_head_dim=v_head_dim if attn_kind == "mha" else mla_original_v_head_dim, + qk_head_dim=qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim, + v_head_dim=v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim, target=target, enable_inline_rope=rope_mode == RopeMode.INLINE, ) @@ 
-391,7 +392,7 @@ def __init__( # pylint: disable=too-many-locals v_head_dim=v_head_dim, target=target, ) - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [] ) flashinfer_mla_mods = ( @@ -420,7 +421,7 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else []) @@ -430,6 +431,11 @@ def __init__( # pylint: disable=too-many-locals if attn_kind == "mla": attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) + + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] args = [ rx.ShapeExpr( [ @@ -482,7 +488,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, - attn_kind: Literal["mha", "mla"], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -553,7 +559,10 @@ def __init__( # pylint: disable=too-many-locals target : Target The target to build the model to. 
""" - + if isinstance(attn_kind, List): + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + else: + attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] bb = rx.BlockBuilder.current() args = [ rx.ShapeExpr( @@ -570,9 +579,7 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(num_key_value_heads), rx.PrimValue(qk_head_dim), rx.PrimValue(v_head_dim), - rx.ShapeExpr( - [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] - ), + rx.ShapeExpr(attn_kind), rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), @@ -614,9 +621,9 @@ def __init__( # pylint: disable=too-many-locals else: # pylint: disable=line-too-long # fmt: off - ragged_qk_head_dim = qk_head_dim if attn_kind == "mha" else mla_original_qk_head_dim - ragged_v_head_dim = v_head_dim if attn_kind == "mha" else mla_original_v_head_dim - args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim + ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim + args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) mha_functions = ( [ rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), @@ -626,7 +633,7 @@ def __init__( 
# pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if attn_kind == "mha" + if (attn_kind == "mha" or isinstance(attn_kind, List)) else [rx.Tuple([]) for _ in range(6)] ) mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else []) @@ -641,7 +648,7 @@ def __init__( # pylint: disable=too-many-locals [ rx.Tuple(attn_merge_functions), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), ] diff --git a/src/runtime/relax_vm/attn_utils.h b/src/runtime/relax_vm/attn_utils.h index 8138aa7bbdf6..f2a63aeb9044 100644 --- a/src/runtime/relax_vm/attn_utils.h +++ b/src/runtime/relax_vm/attn_utils.h @@ -62,13 +62,14 @@ enum class AttnKind : int { 
kMHA = 0, kMLA = 1, kLinearAttn = 2, + kMHASliding = 3, }; /*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */ inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, int64_t v_head_dim) { - if (attn_kind == AttnKind::kMHA) { + if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; } else if (attn_kind == AttnKind::kMLA) { diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 496891f2172b..921cb99a46ee 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -195,6 +195,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_host_; std::vector page_indptr_on_depths_host_; std::vector page_indices_on_depths_host_; + std::vector page_indptr_sliding_window_on_depths_host_; + std::vector page_indices_sliding_window_on_depths_host_; std::vector last_page_len_on_depths_host_; std::vector sliding_window_offset_on_depths_host_; std::vector sink_size_on_depths_host_; @@ -236,6 +238,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector qo_indptr_on_depths_view_; std::vector page_indptr_on_depths_view_; std::vector page_indices_on_depths_view_; + std::vector page_indptr_sliding_window_on_depths_view_; + std::vector page_indices_sliding_window_on_depths_view_; std::vector length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; std::vector tree_attn_mask_view_; @@ -297,7 +301,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { v_head_dim_(v_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), - support_sliding_window_(support_sliding_window), + 
support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end() ? false : support_sliding_window), attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline : rope_mode), @@ -373,6 +377,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); page_indices_on_depths_host_.push_back( HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); + page_indptr_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device)); + page_indices_sliding_window_on_depths_host_.push_back( + HostMemoryVector(num_total_pages, dtype_aux_, preferred_host_device)); last_page_len_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); sliding_window_offset_on_depths_host_.push_back( @@ -423,6 +431,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_view_.push_back(NDArray()); page_indptr_on_depths_view_.push_back(NDArray()); page_indices_on_depths_view_.push_back(NDArray()); + page_indptr_sliding_window_on_depths_view_.push_back(NDArray()); + page_indices_sliding_window_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); tree_attn_mask_view_.push_back(NDArray()); @@ -711,7 +721,30 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { - CHECK(support_sliding_window_) << "The KV cache does not support sliding window."; + // If per layer sliding window exists, enable sliding window for sequence + CHECK(support_sliding_window_ || std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) << "The KV cache does not 
support sliding window."; + // for (AttnKind attn_kind : attn_kinds_) { + // if (attn_kind == AttnKind::kMHASliding) { + // LOG(INFO) << "Found sliding"; + // } else if (attn_kind == AttnKind::kMHA) { + // LOG(INFO) << "Found non-sliding"; + // } else { + // LOG(INFO) << "Found other"; + // } + // } + // if (support_sliding_window_) { + // LOG(INFO) << "Sliding window supported"; + // } else { + // LOG(INFO) << "Sliding window not supported"; + // } + // if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) { + // LOG(INFO) << "Sliding layer found"; + // } else { + // LOG(INFO) << "Sliding layer not found"; + // } + // CHECK(!support_sliding_window_) << "The KV cache does not support sliding window."; + LOG(INFO) << "Enabling sliding window"; + LOG(INFO) << sliding_window_size; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) @@ -933,6 +966,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d]; HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d]; HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d]; + HostMemoryVector& page_indptr_sliding_window_h = page_indptr_sliding_window_on_depths_host_[d]; + HostMemoryVector& page_indices_sliding_window_h = page_indices_sliding_window_on_depths_host_[d]; HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d]; HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; @@ -940,17 +975,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); + page_indptr_sliding_window_h.clear(); + page_indices_sliding_window_h.clear(); last_page_len_h.clear(); sliding_window_offset_h.clear(); sink_size_h.clear();
k_rope_pos_offset_h.clear(); qo_indptr_h.push_back(0); page_indptr_h.push_back(0); + page_indptr_sliding_window_h.push_back(0); for (int i = 0; i < static_cast(chunked_block_ids_arr[d].size()); ++i) { const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i]; qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length); if (block_id == -1) { page_indptr_h.push_back(page_indptr_h.back()); + page_indptr_sliding_window_h.push_back(page_indptr_sliding_window_h.back()); last_page_len_h.push_back(0); sliding_window_offset_h.push_back(0); sink_size_h.push_back(0); @@ -962,7 +1001,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); for (int32_t page_id : block.page_ids) { page_indices_h.push_back(page_id); + // Do the same for page_indices_sliding_window + } + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + + std::min(static_cast(block.page_ids.size()), sequences[d]->sliding_window_size / page_size_)); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); } + // set up the page indices properly by choosing the last (sliding_window_size / + // page_size_) pages (at most) last_page_len_h.push_back( block.seq_length == 0 ? 0 @@ -991,7 +1039,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { total_seq_length += block.seq_length; last_block_id = id; } + // Also add sliding window here? 
page_indptr_h.push_back(page_indptr_h.back() + num_pages); + page_indptr_sliding_window_h.push_back( + page_indptr_sliding_window_h.back() + + std::min(static_cast(num_pages), sequences[d]->sliding_window_size / page_size_)); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { + page_indices_sliding_window_h.push_back(page_indices_h[i]); + } const Block& last_block = global_block_pool_[last_block_id]; last_page_len_h.push_back(total_seq_length == 0 ? 0 @@ -1187,7 +1242,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA); + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || attn_kinds_[layer_id] == AttnKind::kMHASliding); // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1795,7 +1850,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { (block.sliding_window_offset + length_to_slide) % page_size_; // - Free the pages that are fully slidden. - while (page_idx_after_sliding > num_sink_pages) { + while (support_sliding_window_ && page_idx_after_sliding > num_sink_pages) { if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) { free_page_ids_.push_back(block.page_ids[num_sink_pages]); } @@ -1849,7 +1904,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t page_idx = cur_npage; page_idx < tgt_npage; ++page_idx) { // When sliding window is enabled for the seq, we can "borrow temporary pages (-1)", // since the pages need to be slidden out might not have been released. 
- if (free_page_ids_.empty() && seq->sliding_window_size != -1) { + if (free_page_ids_.empty() && seq->sliding_window_size != -1 && support_sliding_window_) { block.page_ids.push_back(kPagedKVCacheTempPageId); } else { block.page_ids.push_back(GetFreePage()); @@ -1860,10 +1915,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // ==================== Slide ==================== // Slide the sequences so that the pages exceed the sliding window are released. SlideWindowForSequence(seq); - for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { - if (block.page_ids[i] == kPagedKVCacheTempPageId) { - // Re-allocate the temporary pages after sliding window release. - block.page_ids[i] = GetFreePage(); + if (support_sliding_window_) { + for (int i = 0; i < static_cast(block.page_ids.size()); ++i) { + if (block.page_ids[i] == kPagedKVCacheTempPageId) { + // Re-allocate the temporary pages after sliding window release. + block.page_ids[i] = GetFreePage(); + } } } @@ -2058,19 +2115,30 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_lse = temp_attn_lse_view_; } + // If layer is sliding window, use sliding window index pointer/indices + NDArray page_indptr; + NDArray page_indices; + if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { + page_indptr = page_indptr_sliding_window_on_depths_view_[d]; + page_indices = page_indices_sliding_window_on_depths_view_[d]; + } else { + page_indptr = page_indptr_on_depths_view_[d]; + page_indices = page_indices_on_depths_view_[d]; + } + if (append_before_attn_ && !is_chain_on_depths_[d]) { ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); f_attention_prefill_with_tree_mask_paged_kv_->MHA( q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + page_indptr, page_indices, length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], 
q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d ICHECK_NOTNULL(f_decode); - f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d], - page_indices_on_depths_view_[d], length_info_on_depths_view_[d], + f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, + page_indices, length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); @@ -2078,7 +2146,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Use prefill kernel for depth d ICHECK_NOTNULL(f_prefill); f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], + page_indptr, page_indices, length_info_on_depths_view_[d], q_rope_position_map_view_, k_rope_pos_offset_view_[d], /*causal=*/false, /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, @@ -2193,7 +2261,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indices_on_depths_view_[d] = aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d], d); } - // 5. length_info_on_depths + + // If per layer sliding window exists, must copy additional vectors + if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) { + // 5. page_indptr_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); + page_indptr_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_sliding_window_on_depths_host_[d], d); + } + // 6. 
page_indices_sliding_window_on_depths + for (int d = 0; d < num_depths_; ++d) { + ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), page_indptr_sliding_window_on_depths_host_[d].back()); + page_indices_sliding_window_on_depths_view_[d] = + aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_sliding_window_on_depths_host_[d], d); + } + } + // 7. length_info_on_depths // last_page_len_on_depths_host_; // sliding_window_offset_on_depths_host_; // sink_size_on_depths_host_; From f8942712a890af37b75f79a6311cc6c479ba5995 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 19 May 2025 01:16:39 -0400 Subject: [PATCH 02/12] Various fixes, code cleanup, but still nvidia memory error --- src/runtime/relax_vm/paged_kv_cache.cc | 69 +++++++++++++++++++------- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 921cb99a46ee..40e12af4d0a1 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -98,6 +98,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const int64_t prefill_chunk_size_; /*! \brief A boolean flag indicating if the KV cache supports sliding window. */ const bool support_sliding_window_; + /*! \brief A boolean flag indicating if the KV cache has per layer sliding window. */ + const bool support_layer_sliding_window_; /*! \brief The attention kinds for each layer. 
*/ const std::vector attn_kinds_; @@ -241,6 +243,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector page_indptr_sliding_window_on_depths_view_; std::vector page_indices_sliding_window_on_depths_view_; std::vector length_info_on_depths_view_; + std::vector layer_sliding_window_length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; @@ -302,6 +305,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end() ? false : support_sliding_window), + support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end()), attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline : rope_mode), @@ -434,6 +438,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr_sliding_window_on_depths_view_.push_back(NDArray()); page_indices_sliding_window_on_depths_view_.push_back(NDArray()); length_info_on_depths_view_.push_back(NDArray()); + layer_sliding_window_length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); tree_attn_mask_view_.push_back(NDArray()); tree_attn_mn_indptr_view_.push_back(NDArray()); @@ -722,7 +727,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { // If per layer sliding window exists, enable sliding window for sequence - CHECK(support_sliding_window_ || std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) << "The KV cache does not support sliding window."; + CHECK(support_sliding_window_ || support_layer_sliding_window_) 
<< "The KV cache does not support sliding window."; // for (AttnKind attn_kind : attn_kinds_) { // if (attn_kind == AttnKind::kMHASliding) { // LOG(INFO) << "Found sliding"; @@ -743,8 +748,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // LOG(INFO) << "Sliding layer not found"; // } // CHECK(!support_sliding_window_) << "The KV cache does not support sliding window."; - LOG(INFO) << "Enabling sliding window"; - LOG(INFO) << sliding_window_size; + // LOG(INFO) << "Enabling sliding window"; + // LOG(INFO) << sliding_window_size; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) @@ -998,15 +1003,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (d < kPagedKVCacheMaxBlockDepth - 1) { // Blocks not at maximum depth const Block& block = global_block_pool_[block_id]; + LOG(INFO) << "CHECKING BASE INDICES"; + LOG(INFO) << page_indptr_h.back() + block.page_ids.size(); + LOG(INFO) << "CHECKING PAGE IDS"; page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); for (int32_t page_id : block.page_ids) { + LOG(INFO) << page_id; page_indices_h.push_back(page_id); // Do the same for page_indices_sliding_window } + LOG(INFO) << "CHECKING SLIDING WINDOW INDPTR"; + LOG(INFO) << page_indptr_sliding_window_h.back(); + LOG(INFO) << std::min(static_cast(block.page_ids.size()), sequences[d]->sliding_window_size / page_size_); + LOG(INFO) << sequences[d]->sliding_window_size; page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), sequences[d]->sliding_window_size / page_size_)); + LOG(INFO) << "CHECKING INDICES"; for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { + LOG(INFO) << page_indices_h[i]; page_indices_sliding_window_h.push_back(page_indices_h[i]); } // set up the page indices properly 
by choosing the last (sliding_window_size / @@ -1849,18 +1864,24 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t page_start_offset_after_sliding = (block.sliding_window_offset + length_to_slide) % page_size_; + if (!support_sliding_window_) { + block.sliding_window_offset = + page_idx_after_sliding * page_size_ + page_start_offset_after_sliding; + return; + } + // - Free the pages that are fully slidden. - while (support_sliding_window_ && page_idx_after_sliding > num_sink_pages) { - if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) { - free_page_ids_.push_back(block.page_ids[num_sink_pages]); - } - block.page_ids.erase(block.page_ids.begin() + num_sink_pages); + while (page_idx_after_sliding > num_sink_pages) { + if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) { + free_page_ids_.push_back(block.page_ids[num_sink_pages]); + } + block.page_ids.erase(block.page_ids.begin() + num_sink_pages); --page_idx_after_sliding; } // - The first sliding page after sliding is either the last sink page, // or the page next to the last sink page. ICHECK(page_idx_after_sliding == num_sink_pages - 1 || - page_idx_after_sliding == num_sink_pages); + page_idx_after_sliding == num_sink_pages); // - Update the length of the sequence and the block. seq->seq_length = seq->sliding_window_size; @@ -1870,9 +1891,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_GE(block.seq_length, block.sink_length); ICHECK_GE(block.sliding_window_offset, block.sink_length); ICHECK_EQ( - (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / - page_size_, - block.page_ids.size()); + (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / + page_size_, + block.page_ids.size()); } /*! 
@@ -1978,7 +1999,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + CHECK(!support_sliding_window_ || !support_layer_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { if (f_attention_decode_ != nullptr && f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) { @@ -2096,9 +2117,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, double sm_scale, bool is_first_kernel) { std::unique_ptr& f_prefill = - !support_sliding_window_ ? f_attention_prefill_ : f_attention_prefill_sliding_window_; + (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? f_attention_prefill_ : f_attention_prefill_sliding_window_; std::unique_ptr& f_decode = - !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_; + (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? 
f_attention_decode_ : f_attention_decode_sliding_window_; CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool cross_attn_computed = false; @@ -2118,12 +2139,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // If layer is sliding window, use sliding window index pointer/indices NDArray page_indptr; NDArray page_indices; + NDArray length_info; if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { page_indptr = page_indptr_sliding_window_on_depths_view_[d]; page_indices = page_indices_sliding_window_on_depths_view_[d]; + length_info = layer_sliding_window_length_info_on_depths_view_[d]; } else { page_indptr = page_indptr_on_depths_view_[d]; page_indices = page_indices_on_depths_view_[d]; + length_info = length_info_on_depths_view_[d]; } if (append_before_attn_ && !is_chain_on_depths_[d]) { @@ -2131,14 +2155,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_prefill_with_tree_mask_paged_kv_->MHA( q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, - length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_, + length_info, k_rope_pos_offset_view_[d], q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d ICHECK_NOTNULL(f_decode); f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, - page_indices, length_info_on_depths_view_[d], + page_indices, length_info, k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); @@ -2147,7 +2171,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_NOTNULL(f_prefill); f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], 
page_indptr, page_indices, - length_info_on_depths_view_[d], q_rope_position_map_view_, + length_info, q_rope_position_map_view_, k_rope_pos_offset_view_[d], /*causal=*/false, /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); @@ -2263,7 +2287,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // If per layer sliding window exists, must copy additional vectors - if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) { + if (support_layer_sliding_window_) { // 5. page_indptr_sliding_window_on_depths for (int d = 0; d < num_depths_; ++d) { ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); @@ -2296,6 +2320,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], &sink_size_on_depths_host_[d], d); } + + if (support_layer_sliding_window_) { + layer_sliding_window_length_info_on_depths_view_[d] = aux_data_manager_->CopyLengthInfoOnDepthAsync( + &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], + &sink_size_on_depths_host_[d], d); + } + } // 6. 
k_rope_pos_offset_on_depths for (int d = 0; d < num_depths_; ++d) { From 6a07656d2792d40da8f208b7630b2983e0ef705e Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 2 Jun 2025 01:15:51 -0400 Subject: [PATCH 03/12] Fix correctness bug --- src/runtime/relax_vm/paged_kv_cache.cc | 92 +++++++++++++++++--------- 1 file changed, 62 insertions(+), 30 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 40e12af4d0a1..02bb13473449 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -203,6 +203,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector sliding_window_offset_on_depths_host_; std::vector sink_size_on_depths_host_; std::vector k_rope_pos_offset_on_depths_host_; + std::vector k_rope_pos_offset_sliding_window_on_depths_host_; HostMemoryVector k_ragged_rope_pos_offset_host_; HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; @@ -245,6 +246,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector length_info_on_depths_view_; std::vector layer_sliding_window_length_info_on_depths_view_; std::vector k_rope_pos_offset_view_; + std::vector k_rope_pos_offset_sliding_window_view_; std::vector tree_attn_mask_view_; std::vector tree_attn_mn_indptr_view_; @@ -393,6 +395,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); k_rope_pos_offset_on_depths_host_.push_back( HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); + k_rope_pos_offset_sliding_window_on_depths_host_.push_back( + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mask_host_.push_back(HostMemoryVector(kTreeAttnMaxTreeSize * 2 * reserved_num_seqs, dtype_aux_, preferred_host_device)); tree_attn_mn_indptr_host_.push_back( @@ -440,6 +444,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj { length_info_on_depths_view_.push_back(NDArray()); layer_sliding_window_length_info_on_depths_view_.push_back(NDArray()); k_rope_pos_offset_view_.push_back(NDArray()); + k_rope_pos_offset_sliding_window_view_.push_back(NDArray()); tree_attn_mask_view_.push_back(NDArray()); tree_attn_mn_indptr_view_.push_back(NDArray()); is_chain_on_depths_.push_back(true); @@ -977,6 +982,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; + HostMemoryVector& k_rope_pos_offset_sliding_window_h = k_rope_pos_offset_sliding_window_on_depths_host_[d]; qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); @@ -986,6 +992,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sliding_window_offset_h.clear(); sink_size_h.clear(); k_rope_pos_offset_h.clear(); + k_rope_pos_offset_sliding_window_h.clear(); qo_indptr_h.push_back(0); page_indptr_h.push_back(0); page_indptr_sliding_window_h.push_back(0); @@ -999,29 +1006,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sliding_window_offset_h.push_back(0); sink_size_h.push_back(0); k_rope_pos_offset_h.push_back(0); + k_rope_pos_offset_sliding_window_h.push_back(0); } else { if (d < kPagedKVCacheMaxBlockDepth - 1) { // Blocks not at maximum depth const Block& block = global_block_pool_[block_id]; - LOG(INFO) << "CHECKING BASE INDICES"; - LOG(INFO) << page_indptr_h.back() + block.page_ids.size(); - LOG(INFO) << "CHECKING PAGE IDS"; page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size()); for (int32_t page_id : block.page_ids) { - LOG(INFO) << page_id; page_indices_h.push_back(page_id); // Do the same for page_indices_sliding_window } - LOG(INFO) << "CHECKING SLIDING WINDOW INDPTR"; - LOG(INFO) << page_indptr_sliding_window_h.back(); 
- LOG(INFO) << std::min(static_cast(block.page_ids.size()), sequences[d]->sliding_window_size / page_size_); - LOG(INFO) << sequences[d]->sliding_window_size; + + // For sliding window, the first page and last page will both be partially used page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + - std::min(static_cast(block.page_ids.size()), sequences[d]->sliding_window_size / page_size_)); - LOG(INFO) << "CHECKING INDICES"; + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 1 : 0) + )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { - LOG(INFO) << page_indices_h[i]; page_indices_sliding_window_h.push_back(page_indices_h[i]); } // set up the page indices properly by choosing the last (sliding_window_size / @@ -1032,9 +1033,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (block.seq_length < sequences[d]->sliding_window_size) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(block.sliding_window_offset); + } sink_size_h.push_back(block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + + // If sliding window, we need to calculate the positional offset + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + } + // if (page_indices_sliding_window_h.size() > 0) { + // LOG(INFO) << "WOffset: "<< sliding_window_offset_h.back() << " SWIdx: "<< 
page_indptr_sliding_window_h.back() << " LastPgIdx: " << page_indices_sliding_window_h.back() << " LastPgLen: " << last_page_len_h.back() << " RopeOffset: " << k_rope_pos_offset_h.back() << " RopeOffsetSlide: " << k_rope_pos_offset_sliding_window_h.back(); + // } } else { // Blocks at maximum depth const Block& block = global_block_pool_[block_id]; @@ -1057,8 +1074,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Also add sliding window here? page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + - std::min(static_cast(num_pages), sequences[d]->sliding_window_size / page_size_)); + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 1 : 0) + )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); } @@ -1069,9 +1087,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { last_block.sliding_window_offset - 1) % page_size_ + 1); - sliding_window_offset_h.push_back(last_block.sliding_window_offset); + if (support_layer_sliding_window_) { + if (last_block.seq_length < sequences[d]->sliding_window_size) { + sliding_window_offset_h.push_back(0); + } else { + sliding_window_offset_h.push_back(last_block.seq_length % page_size_); + } + } else { + sliding_window_offset_h.push_back(last_block.sliding_window_offset); + } sink_size_h.push_back(last_block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); + if (support_layer_sliding_window_) { + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + } } } } @@ -1833,7 +1862,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { */ void 
SlideWindowForSequence(Sequence* seq) { // - No action when the sequence is not enabled for sliding window. - if (seq->sliding_window_size == -1) { + if (seq->sliding_window_size == -1 || !support_sliding_window_) { return; } // - No action when the sequence length does not exceed the window size. @@ -1864,18 +1893,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t page_start_offset_after_sliding = (block.sliding_window_offset + length_to_slide) % page_size_; - if (!support_sliding_window_) { - block.sliding_window_offset = - page_idx_after_sliding * page_size_ + page_start_offset_after_sliding; - return; - } - // - Free the pages that are fully slidden. while (page_idx_after_sliding > num_sink_pages) { - if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) { - free_page_ids_.push_back(block.page_ids[num_sink_pages]); - } - block.page_ids.erase(block.page_ids.begin() + num_sink_pages); + if (block.page_ids[num_sink_pages] != kPagedKVCacheTempPageId) { + free_page_ids_.push_back(block.page_ids[num_sink_pages]); + } + block.page_ids.erase(block.page_ids.begin() + num_sink_pages); --page_idx_after_sliding; } // - The first sliding page after sliding is either the last sink page, @@ -2140,14 +2163,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray page_indptr; NDArray page_indices; NDArray length_info; + NDArray k_rope_pos; if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { page_indptr = page_indptr_sliding_window_on_depths_view_[d]; page_indices = page_indices_sliding_window_on_depths_view_[d]; length_info = layer_sliding_window_length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_sliding_window_view_[d]; } else { page_indptr = page_indptr_on_depths_view_[d]; page_indices = page_indices_on_depths_view_[d]; length_info = length_info_on_depths_view_[d]; + k_rope_pos = k_rope_pos_offset_view_[d]; } if (append_before_attn_ && !is_chain_on_depths_[d]) { @@ -2155,7 
+2181,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_prefill_with_tree_mask_paged_kv_->MHA( q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, - length_info, k_rope_pos_offset_view_[d], q_rope_position_map_view_, + length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { @@ -2163,7 +2189,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_NOTNULL(f_decode); f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, - k_rope_pos_offset_view_[d], q_rope_position_map_view_, rope_mode_, + k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } else { @@ -2172,7 +2198,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, length_info, q_rope_position_map_view_, - k_rope_pos_offset_view_[d], /*causal=*/false, + k_rope_pos, /*causal=*/false, /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); } @@ -2334,6 +2360,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qo_indptr_on_depths_host_[d].size()); k_rope_pos_offset_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( &k_rope_pos_offset_on_depths_host_[d], d); + if (support_layer_sliding_window_) { + ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, + qo_indptr_on_depths_host_[d].size()); + k_rope_pos_offset_sliding_window_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( + &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); + } } // 7. 
cur_append_lengths_indptr cur_append_length_indptr_view_ = From 989d6ecdbd567adc0678101542b088cf657efaa7 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 9 Jun 2025 01:39:36 -0400 Subject: [PATCH 04/12] Fix correctness issues due to rope frequency --- src/runtime/relax_vm/paged_kv_cache.cc | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 02bb13473449..328acf655756 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1050,7 +1050,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); } // if (page_indices_sliding_window_h.size() > 0) { - // LOG(INFO) << "WOffset: "<< sliding_window_offset_h.back() << " SWIdx: "<< page_indptr_sliding_window_h.back() << " LastPgIdx: " << page_indices_sliding_window_h.back() << " LastPgLen: " << last_page_len_h.back() << " RopeOffset: " << k_rope_pos_offset_h.back() << " RopeOffsetSlide: " << k_rope_pos_offset_sliding_window_h.back(); + // LOG(INFO) << "WOffset: "<< sliding_window_offset_h.back() << " SWIdx: "<< page_indptr_sliding_window_h.back() << " LastPgIdx: " << page_indices_sliding_window_h.back() << " LastPgLen: " << last_page_len_h.back() << " RopeOffset: " << k_rope_pos_offset_h.back() << " RopeOffsetSlide: " << k_rope_pos_offset_sliding_window_h.back() << " Block Start: " << block.start_pos; // } } else { // Blocks at maximum depth @@ -2164,16 +2164,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray page_indices; NDArray length_info; NDArray k_rope_pos; + double rotary_theta; + double rotary_scale; + if (attn_kinds_[local_layer_id + layer_id_begin_offset_] == AttnKind::kMHASliding) { page_indptr = page_indptr_sliding_window_on_depths_view_[d]; page_indices = 
page_indices_sliding_window_on_depths_view_[d]; length_info = layer_sliding_window_length_info_on_depths_view_[d]; k_rope_pos = k_rope_pos_offset_sliding_window_view_[d]; + rotary_theta = 10000; + rotary_scale = 1; } else { page_indptr = page_indptr_on_depths_view_[d]; page_indices = page_indices_on_depths_view_[d]; length_info = length_info_on_depths_view_[d]; k_rope_pos = k_rope_pos_offset_view_[d]; + rotary_theta = rotary_theta_; + rotary_scale = rotary_scale_; } if (append_before_attn_ && !is_chain_on_depths_[d]) { @@ -2182,15 +2189,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, length_info, k_rope_pos, q_rope_position_map_view_, - tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale_, - rotary_theta_, sm_scale, attn_output, attn_lse, compute_stream_); + tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale, + rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d ICHECK_NOTNULL(f_decode); f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, k_rope_pos, q_rope_position_map_view_, rope_mode_, - rotary_scale_, rotary_theta_, sm_scale, attn_output, attn_lse, + rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } else { // Use prefill kernel for depth d @@ -2199,7 +2206,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr, page_indices, length_info, q_rope_position_map_view_, k_rope_pos, /*causal=*/false, - /*rotary_mode=*/rope_mode_, rotary_scale_, rotary_theta_, sm_scale, + /*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } From 9b46bf287e97bde6171807d80b9a3d185488cf8c Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Tue, 10 Jun 2025 00:02:54 -0400 Subject: [PATCH 05/12] Code 
cleanup --- src/runtime/relax_vm/paged_kv_cache.cc | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 328acf655756..355c3fc371c3 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -733,28 +733,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t attn_sink_size) final { // If per layer sliding window exists, enable sliding window for sequence CHECK(support_sliding_window_ || support_layer_sliding_window_) << "The KV cache does not support sliding window."; - // for (AttnKind attn_kind : attn_kinds_) { - // if (attn_kind == AttnKind::kMHASliding) { - // LOG(INFO) << "Found sliding"; - // } else if (attn_kind == AttnKind::kMHA) { - // LOG(INFO) << "Found non-sliding"; - // } else { - // LOG(INFO) << "Found other"; - // } - // } - // if (support_sliding_window_) { - // LOG(INFO) << "Sldiing window supported"; - // } else { - // LOG(INFO) << "Sldiing window not supported"; - // } - // if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHASliding) != attn_kinds_.end()) { - // LOG(INFO) << "Sliding layer found"; - // } else { - // LOG(INFO) << "Sliding layer not found"; - // } - // CHECK(!support_sliding_window_) << "The KV cache does not support sliding window."; - // LOG(INFO) << "Enabling sliding window"; - // LOG(INFO) << sliding_window_size; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) From ff16486cb99debf79404f6639dd731f90bceb9e8 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Tue, 10 Jun 2025 23:51:57 -0400 Subject: [PATCH 06/12] Add static cast --- src/runtime/relax_vm/paged_kv_cache.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 
355c3fc371c3..d36cffb3057c 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -998,7 +998,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // For sliding window, the first page and last page will both be partially used page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 1 : 0) + static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); @@ -1027,9 +1027,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (support_layer_sliding_window_) { k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); } - // if (page_indices_sliding_window_h.size() > 0) { - // LOG(INFO) << "WOffset: "<< sliding_window_offset_h.back() << " SWIdx: "<< page_indptr_sliding_window_h.back() << " LastPgIdx: " << page_indices_sliding_window_h.back() << " LastPgLen: " << last_page_len_h.back() << " RopeOffset: " << k_rope_pos_offset_h.back() << " RopeOffsetSlide: " << k_rope_pos_offset_sliding_window_h.back() << " Block Start: " << block.start_pos; - // } } else { // Blocks at maximum depth const Block& block = global_block_pool_[block_id]; @@ -1049,11 +1046,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { total_seq_length += block.seq_length; last_block_id = id; } - // Also add sliding window here? 
+ page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - sequences[d]->sliding_window_size / page_size_ + (block.seq_length % sequences[d]->sliding_window_size ? 1 : 0) + static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); From 5a4e413704ee45a66e703ae4aff1617f4eea8356 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sun, 15 Jun 2025 22:02:32 -0400 Subject: [PATCH 07/12] Unity segfault fix --- src/runtime/vm/paged_kv_cache.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index f659f6ac45be..379bb849e8be 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -1002,8 +1002,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // For sliding window, the first page and last page will both be partially used page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 
1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); @@ -1054,8 +1054,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) + page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), + static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); From 085693be8bb3cff1ad9aa680603ced357da4fb0b Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sun, 15 Jun 2025 23:45:14 -0400 Subject: [PATCH 08/12] Fix bug --- src/runtime/vm/paged_kv_cache.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 379bb849e8be..28f3e0558427 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -1077,9 +1077,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sliding_window_offset_h.push_back(last_block.sliding_window_offset); } sink_size_h.push_back(last_block.sink_length); - k_rope_pos_offset_h.push_back(block.start_pos); + k_rope_pos_offset_h.push_back(last_block.start_pos); if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, last_block.start_pos + last_block.seq_length - 
sequences[d]->sliding_window_size)); } } } From 7f1d1861de2401001675f00c2eff8d15d824c06a Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 16 Jun 2025 01:05:05 -0400 Subject: [PATCH 09/12] Test log --- src/runtime/vm/paged_kv_cache.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 28f3e0558427..66cde04948a3 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -1051,7 +1051,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { total_seq_length += block.seq_length; last_block_id = id; } - + if (sequences[d]->sliding_window_size > 0) { + LOG(INFO) << "TEST"; + } page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), @@ -1077,9 +1079,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sliding_window_offset_h.push_back(last_block.sliding_window_offset); } sink_size_h.push_back(last_block.sink_length); - k_rope_pos_offset_h.push_back(last_block.start_pos); + k_rope_pos_offset_h.push_back(block.start_pos); if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, last_block.start_pos + last_block.seq_length - sequences[d]->sliding_window_size)); + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); } } } From 13880f1fd3b27a15e1f4bba2ab77cee84e606e22 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 16 Jun 2025 14:03:48 -0400 Subject: [PATCH 10/12] Change sliding window size to temporary constant --- src/runtime/vm/paged_kv_cache.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 66cde04948a3..c52d3b110673 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ 
b/src/runtime/vm/paged_kv_cache.cc @@ -1003,7 +1003,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // For sliding window, the first page and last page will both be partially used page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) + static_cast(1024 / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); @@ -1017,7 +1017,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_size_ + 1); if (support_layer_sliding_window_) { - if (block.seq_length < sequences[d]->sliding_window_size) { + if (block.seq_length < 1024) { sliding_window_offset_h.push_back(0); } else { sliding_window_offset_h.push_back(block.seq_length % page_size_); @@ -1030,7 +1030,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // If sliding window, we need to calculate the positional offset if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - 1024)); } } else { // Blocks at maximum depth @@ -1051,13 +1051,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { total_seq_length += block.seq_length; last_block_id = id; } - if (sequences[d]->sliding_window_size > 0) { - LOG(INFO) << "TEST"; - } page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( page_indptr_sliding_window_h.back() + std::min(static_cast(block.page_ids.size()), - static_cast(sequences[d]->sliding_window_size / page_size_ + (block.seq_length % page_size_ ? 
1 : 0)) + static_cast(1024 / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) )); for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); @@ -1070,7 +1067,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { page_size_ + 1); if (support_layer_sliding_window_) { - if (last_block.seq_length < sequences[d]->sliding_window_size) { + if (last_block.seq_length < 1024) { sliding_window_offset_h.push_back(0); } else { sliding_window_offset_h.push_back(last_block.seq_length % page_size_); @@ -1081,7 +1078,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sink_size_h.push_back(last_block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - sequences[d]->sliding_window_size)); + k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - 1024)); } } } From 675d0cf39a2c7a6ac1553a1bcca7bc4d9dc1691c Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 16 Jun 2025 21:27:28 -0400 Subject: [PATCH 11/12] Fix lint --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 22 ++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index f74550c0f591..a1d742739aca 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -302,7 +302,7 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, - attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, 
prefill_chunk_size: tir.Var, @@ -378,8 +378,16 @@ def __init__( # pylint: disable=too-many-locals dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - qk_head_dim=qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim, - v_head_dim=v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim, + qk_head_dim=( + qk_head_dim + if (attn_kind == "mha" or isinstance(attn_kind, List)) + else mla_original_qk_head_dim + ), + v_head_dim=( + v_head_dim + if (attn_kind == "mha" or isinstance(attn_kind, List)) + else mla_original_v_head_dim + ), target=target, enable_inline_rope=rope_mode == RopeMode.INLINE, ) @@ -488,7 +496,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, - attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla","mha_sliding"]]], + attn_kind: Union[Literal["mha", "mla"], List[Literal["mha", "mla", "mha_sliding"]]], max_batch_size: tir.Var, max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, @@ -560,9 +568,11 @@ def __init__( # pylint: disable=too-many-locals The target to build the model to. 
""" if isinstance(attn_kind, List): - attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] + attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] else: - attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] + attn_kind = [ + int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers) + ] bb = rx.BlockBuilder.current() args = [ rx.ShapeExpr( From 95dff313a7c4bc403308608afbddec85961e8510 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Mon, 16 Jun 2025 22:55:45 -0400 Subject: [PATCH 12/12] Fix c++ lint --- src/runtime/vm/paged_kv_cache.cc | 119 ++++++++++++++++++------------- 1 file changed, 71 insertions(+), 48 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index c52d3b110673..2af4b19b06b1 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -307,8 +307,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { v_head_dim_(v_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), - support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end() ? false : support_sliding_window), - support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), AttnKind::kMHASliding) != attn_kinds.end()), + support_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), + AttnKind::kMHASliding) != attn_kinds.end() + ? false + : support_sliding_window), + support_layer_sliding_window_(std::find(attn_kinds.begin(), attn_kinds.end(), + AttnKind::kMHASliding) != attn_kinds.end()), attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? 
RoPEMode::kInline : rope_mode), @@ -737,7 +741,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { // If per layer sliding window exists, enable sliding window for sequence - CHECK(support_sliding_window_ || support_layer_sliding_window_) << "The KV cache does not support sliding window."; + CHECK(support_sliding_window_ || support_layer_sliding_window_) + << "The KV cache does not support sliding window."; auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; CHECK_GE(attn_sink_size, 0) @@ -959,13 +964,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { HostMemoryVector& qo_indptr_h = qo_indptr_on_depths_host_[d]; HostMemoryVector& page_indptr_h = page_indptr_on_depths_host_[d]; HostMemoryVector& page_indices_h = page_indices_on_depths_host_[d]; - HostMemoryVector& page_indptr_sliding_window_h = page_indptr_sliding_window_on_depths_host_[d]; - HostMemoryVector& page_indices_sliding_window_h = page_indices_sliding_window_on_depths_host_[d]; + HostMemoryVector& page_indptr_sliding_window_h = + page_indptr_sliding_window_on_depths_host_[d]; + HostMemoryVector& page_indices_sliding_window_h = + page_indices_sliding_window_on_depths_host_[d]; HostMemoryVector& last_page_len_h = last_page_len_on_depths_host_[d]; HostMemoryVector& sliding_window_offset_h = sliding_window_offset_on_depths_host_[d]; HostMemoryVector& sink_size_h = sink_size_on_depths_host_[d]; HostMemoryVector& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d]; - HostMemoryVector& k_rope_pos_offset_sliding_window_h = k_rope_pos_offset_sliding_window_on_depths_host_[d]; + HostMemoryVector& k_rope_pos_offset_sliding_window_h = + k_rope_pos_offset_sliding_window_on_depths_host_[d]; qo_indptr_h.clear(); page_indptr_h.clear(); page_indices_h.clear(); @@ -1002,10 +1010,12 @@ class 
PagedAttentionKVCacheObj : public AttentionKVCacheObj { // For sliding window, the first page and last page will both be partially used page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + std::min(static_cast<int64_t>(block.page_ids.size()), - static_cast<int64_t>(1024 / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) - )); - for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast<int>(page_indices_h.size()); i++) { + page_indptr_sliding_window_h.back() + + std::min(static_cast<int64_t>(block.page_ids.size()), + static_cast<int64_t>(1024 / page_size_ + + (block.seq_length % page_size_ ? 1 : 0)))); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); + i < static_cast<int>(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); } // set up the page indices properly by choosing the last (sliding_window_size / @@ -1030,7 +1040,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // If sliding window, we need to calculate the positional offset if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - 1024)); + k_rope_pos_offset_sliding_window_h.push_back( + std::max(0, block.start_pos + block.seq_length - 1024)); } } else { // Blocks at maximum depth @@ -1053,10 +1064,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } page_indptr_h.push_back(page_indptr_h.back() + num_pages); page_indptr_sliding_window_h.push_back( - page_indptr_sliding_window_h.back() + std::min(static_cast<int64_t>(block.page_ids.size()), - static_cast<int64_t>(1024 / page_size_ + (block.seq_length % page_size_ ? 1 : 0)) - )); - for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); i < static_cast<int>(page_indices_h.size()); i++) { + page_indptr_sliding_window_h.back() + + std::min(static_cast<int64_t>(block.page_ids.size()), + static_cast<int64_t>(1024 / page_size_ + + (block.seq_length % page_size_ ?
1 : 0)))); + for (int i = page_indices_h.size() - page_indptr_sliding_window_h.back(); + i < static_cast<int>(page_indices_h.size()); i++) { page_indices_sliding_window_h.push_back(page_indices_h[i]); } const Block& last_block = global_block_pool_[last_block_id]; @@ -1078,7 +1091,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sink_size_h.push_back(last_block.sink_length); k_rope_pos_offset_h.push_back(block.start_pos); if (support_layer_sliding_window_) { - k_rope_pos_offset_sliding_window_h.push_back(std::max(0, block.start_pos + block.seq_length - 1024)); + k_rope_pos_offset_sliding_window_h.push_back( + std::max(0, block.start_pos + block.seq_length - 1024)); } } } @@ -1265,7 +1279,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { NDArray pages = pages_[local_layer_id]; CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || attn_kinds_[layer_id] == AttnKind::kMHASliding); + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || + attn_kinds_[layer_id] == AttnKind::kMHASliding); // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1883,7 +1898,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // - The first sliding page after sliding is either the last sink page, // or the page next to the last sink page. ICHECK(page_idx_after_sliding == num_sink_pages - 1 || - page_idx_after_sliding == num_sink_pages); + page_idx_after_sliding == num_sink_pages); // - Update the length of the sequence and the block.
seq->seq_length = seq->sliding_window_size; @@ -1893,9 +1908,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK_GE(block.seq_length, block.sink_length); ICHECK_GE(block.sliding_window_offset, block.sink_length); ICHECK_EQ( - (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / - page_size_, - block.page_ids.size()); + (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / + page_size_, + block.page_ids.size()); } /*! @@ -2001,7 +2016,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_ || !support_layer_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + CHECK(!support_sliding_window_ || !support_layer_sliding_window_) + << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { if (f_attention_decode_ != nullptr && f_attention_decode_->backend_kind == AttnBackendKind::kFlashInfer) { @@ -2119,9 +2135,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool MHACrossAttnInternal(int64_t local_layer_id, NDArray q_data, NDArray o_data, NDArray lse_data, double sm_scale, bool is_first_kernel) { std::unique_ptr& f_prefill = - (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? f_attention_prefill_ : f_attention_prefill_sliding_window_; + (!support_sliding_window_ && + attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) + ? f_attention_prefill_ + : f_attention_prefill_sliding_window_; std::unique_ptr& f_decode = - (!support_sliding_window_ && attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? f_attention_decode_ : f_attention_decode_sliding_window_; + (!support_sliding_window_ && + attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) + ? 
f_attention_decode_ + : f_attention_decode_sliding_window_; CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; bool cross_attn_computed = false; @@ -2165,26 +2187,22 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (append_before_attn_ && !is_chain_on_depths_[d]) { ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); f_attention_prefill_with_tree_mask_paged_kv_->MHA( - q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr, page_indices, - length_info, k_rope_pos, q_rope_position_map_view_, - tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d], rope_mode_, rotary_scale, - rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); + q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, + length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], + tree_attn_mask_view_[d], rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, + attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d ICHECK_NOTNULL(f_decode); - f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, - page_indices, length_info, - k_rope_pos, q_rope_position_map_view_, rope_mode_, - rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, - compute_stream_); + f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, + k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta, + sm_scale, attn_output, attn_lse, compute_stream_); } else { // Use prefill kernel for depth d ICHECK_NOTNULL(f_prefill); - f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], - page_indptr, page_indices, - length_info, q_rope_position_map_view_, - k_rope_pos, /*causal=*/false, + f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, + page_indices, length_info, q_rope_position_map_view_, k_rope_pos, 
+ /*causal=*/false, /*rotary_mode=*/rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } @@ -2302,15 +2320,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (support_layer_sliding_window_) { // 5. page_indptr_sliding_window_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); + ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), + qo_indptr_on_depths_host_[d].size()); page_indptr_sliding_window_on_depths_view_[d] = - aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_sliding_window_on_depths_host_[d], d); + aux_data_manager_->CopyPageIndptrOnDepthAsync( + &page_indptr_sliding_window_on_depths_host_[d], d); } // 6. page_indices_sliding_window_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), page_indptr_sliding_window_on_depths_host_[d].back()); + ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), + page_indptr_sliding_window_on_depths_host_[d].back()); page_indices_sliding_window_on_depths_view_[d] = - aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_sliding_window_on_depths_host_[d], d); + aux_data_manager_->CopyPageIndicesOnDepthAsync( + &page_indices_sliding_window_on_depths_host_[d], d); } } // 7. length_info_on_depths @@ -2334,11 +2356,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (support_layer_sliding_window_) { - layer_sliding_window_length_info_on_depths_view_[d] = aux_data_manager_->CopyLengthInfoOnDepthAsync( - &last_page_len_on_depths_host_[d], &sliding_window_offset_on_depths_host_[d], - &sink_size_on_depths_host_[d], d); + layer_sliding_window_length_info_on_depths_view_[d] = + aux_data_manager_->CopyLengthInfoOnDepthAsync(&last_page_len_on_depths_host_[d], + &sliding_window_offset_on_depths_host_[d], + &sink_size_on_depths_host_[d], d); } - } // 6. 
k_rope_pos_offset_on_depths for (int d = 0; d < num_depths_; ++d) { @@ -2349,8 +2371,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (support_layer_sliding_window_) { ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, qo_indptr_on_depths_host_[d].size()); - k_rope_pos_offset_sliding_window_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( - &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); + k_rope_pos_offset_sliding_window_view_[d] = + aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( + &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); } } // 7. cur_append_lengths_indptr