From 1d202799fa5eb5240d95a7d52e8cb475b297216e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 29 Jan 2025 23:00:43 -0500 Subject: [PATCH] [Runtime][KVCache] Initial interface setup for MLA This PR introduces the initial KV cache interface setup for multi-head latent attention in DeepSeek models. Some interface implementations are marked as TODO and will be implemented in the near future. --- src/runtime/relax_vm/kv_state.h | 63 +++++ src/runtime/relax_vm/paged_kv_cache.cc | 313 +++++++++++++++++++++---- 2 files changed, 330 insertions(+), 46 deletions(-) diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 7df3215d088e..77c17d1c555f 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -181,6 +181,69 @@ class AttentionKVCacheObj : public KVStateObj { virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, NDArray o_data, double attn_score_scaling_factor) = 0; + /*! + * \brief Compute attention with Q/K/V data. + * \param layer_id The model layer where the attention compute happens. + * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)` + * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)` + * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)` + * \param mask The input mask data, in layout `(total_sqr_length)`. + * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. + * \param attn_score_scaling_factor The additional attention scaling factor. + */ + virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data, + NDArray v_data, Optional mask, NDArray o_data, + double attn_score_scaling_factor) = 0; + + /*! + * \brief Compute multi-head latent attention after applying weight absorption. + * \param layer_id The model layer where the attention compute happens. + * \param q_data The input Q data, in layout `(total_length, num_qo_heads, qk_head_dim)` + * \param compressed_kv_data The compressed latent KV data, in layout + * `(total_length, num_kv_heads, kv_lora_rank)` + * \param k_pe_data The positional embedding part of K data, in layout + * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim` + * equals qk_head_dim + * \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`. + * \param attn_score_scaling_factor The additional attention scaling factor. + */ + virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, + NDArray k_pe_data, NDArray o_data, double attn_score_scaling_factor) = 0; + + /*! + * \brief Compute multi-head latent attention in normal style. + * \param layer_id The model layer where the attention compute happens. + * \param q_data The input Q data, in layout + * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)` + * \param k_data The input K data, in layout + * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)` + * \param v_data The input V data, in layout + * `(total_length, num_qo_heads, v_head_dim)` + * \param compressed_kv_data The compressed latent KV data, in layout + * `(total_length, num_kv_heads, kv_lora_rank)` + * \param k_pe_data The positional embedding part of K data, in layout + * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + qk_rope_head_dim` + * equals qk_head_dim + * \param o_data The output O data, in layout `(total_length, num_qo_heads, v_head_dim)`. 
+ * \param attn_score_scaling_factor The additional attention scaling factor. + */ + virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data, + double attn_score_scaling_factor) = 0; + + /*! + * \brief Compute linear attention with Q/K/V data. + * \param layer_id The model layer where the attention compute happens. + * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`. + * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`. + * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`. + * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`. + * \param attn_score_scaling_factor The additional attention scaling factor. + * \sa AttentionKVCache::Attention + */ + virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + double attn_score_scaling_factor) = 0; + /************** Positions **************/ /*! diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 81c55bfcb645..8e5dfb4bd81e 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -64,6 +64,33 @@ constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024; /*! \brief The id of the temporary logical page, which is useful for sliding window. */ constexpr const int kPagedKVCacheTempPageId = -1; +/*! + * \brief The supported attention kinds in PagedKVCache. + * "MHA" means multi-head attention, multi-query attention and grouped query attention in general. + * "MLA" means multi-head latent attention. + * "LinearAttn" means linear attention. + */ +enum class AttnKind : int { + kMHA = 0, + kMLA = 1, + kLinearAttn = 2, +}; + +ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, + int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, + int64_t v_head_dim, int64_t qk_rope_head_dim) { + if (attn_kind == AttnKind::kMHA) { + // Ignore v_head_dim since multi-head attention requires K/V to have the same head dim. + return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}; + } else if (attn_kind == AttnKind::kMLA) { + return {num_total_pages, num_kv_heads, page_size, qk_head_dim + qk_rope_head_dim}; + } else if (attn_kind == AttnKind::kLinearAttn) { + return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; + } + ICHECK(false); + throw; +} + /*! * \brief The block structure in paged KV cache with common prefix support. * Each block contains a list of pages for cached KV data. @@ -940,13 +967,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The number of key/value heads in the model. */ const int64_t num_kv_heads_; /*! \brief The number of features each head has. */ - const int64_t head_dim_; + const int64_t qk_head_dim_; + /*! + * \brief The number of features each head has for V. + * For layers that use multi-head attention, this field is overridden by qk_head_dim. + */ + const int64_t v_head_dim_; + /*! + * \brief The number of features each head has for RoPE in multi-head latent attention. + * This field is ignored for non-MLA. + */ + const int64_t qk_rope_head_dim_; /*! \brief The number of total pages allocated in KV cache. */ const int64_t num_total_pages_; /*! \brief The maximum total sequence length in a prefill. */ const int64_t prefill_chunk_size_; /*! 
\brief A boolean flag indicating if the KV cache supports sliding window. */ const bool support_sliding_window_; + /*! \brief The attention kinds for each layer. */ + const std::vector attn_kinds_; /*! \brief The RoPE application mode of KV cache.*/ const RoPEMode rope_mode_; @@ -967,7 +1006,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { * If KV transfer function is specifed, pages_ will be allocated by NVSHMEM as a whole NDArray. * pages_ will contain tensor view of each layer. * Otherwise, pages_ has `num_layers` NDArrays, each of them - * has layout (num_pages, 2, num_heads, page_size, head_dim). + * has layout (num_pages, 2, num_heads, page_size, qk_head_dim). * Along on the "2" dimension, index 0 stands for K and 1 stands for V. */ std::vector pages_; @@ -1086,6 +1125,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector tree_attn_mn_indptr_view_; PackedFunc f_transpose_append_; + PackedFunc f_transpose_append_mla_; Optional f_transfer_kv_; Optional f_transfer_kv_page_to_page_ = NullOpt; PackedFunc f_compact_copy_; @@ -1102,8 +1142,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Optional f_attention_prefill_end_forward_; Optional f_attention_decode_begin_forward_; Optional f_attention_decode_end_forward_; + PackedFunc f_mla_prefill_; + PackedFunc f_mla_decode_; + PackedFunc f_mla_prefill_ragged_normal_; + PackedFunc f_mla_prefill_ragged_absorbed_; PackedFunc f_merge_inplace_; PackedFunc f_split_rotary_; + PackedFunc f_separate_rotary_; PackedFunc f_copy_single_page_; Optional f_debug_get_kv_; @@ -1120,37 +1165,45 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */ explicit PagedAttentionKVCacheObj( int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset, // - int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t reserved_num_seqs, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, + int64_t qk_rope_head_dim, std::vector attn_kinds, int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size, bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta, Optional rope_ext_factors, bool enable_kv_transfer, DLDataType dtype, Device device, - PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill, - PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window, - PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged, - PackedFunc f_attention_prefill_with_tree_mask, + PackedFunc f_transpose_append, PackedFunc f_transpose_append_mla, PackedFunc f_compact_copy, + PackedFunc f_attention_prefill, PackedFunc f_attention_decode, + PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window, + PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask, PackedFunc f_attention_prefill_with_tree_mask_paged_kv, Optional f_attention_prefill_ragged_begin_forward, Optional f_attention_prefill_ragged_end_forward, Optional f_attention_prefill_begin_forward, Optional f_attention_prefill_end_forward, Optional f_attention_decode_begin_forward, - Optional f_attention_decode_end_forward, PackedFunc f_merge_inplace, - PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional f_debug_get_kv) + Optional f_attention_decode_end_forward, PackedFunc f_mla_prefill, + PackedFunc f_mla_decode, PackedFunc 
f_mla_prefill_ragged_normal, + PackedFunc f_mla_prefill_ragged_absorbed, PackedFunc f_merge_inplace, + PackedFunc f_split_rotary, PackedFunc f_separate_rotary, PackedFunc f_copy_single_page, + Optional f_debug_get_kv) : page_size_(page_size), num_layers_(num_layers), layer_id_begin_offset_(layer_id_begin_offset), num_qo_heads_(num_qo_heads), num_kv_heads_(num_kv_heads), - head_dim_(head_dim), + qk_head_dim_(qk_head_dim), + v_head_dim_(v_head_dim), + qk_rope_head_dim_(qk_rope_head_dim), num_total_pages_(num_total_pages), prefill_chunk_size_(prefill_chunk_size), support_sliding_window_(support_sliding_window), + attn_kinds_(std::move(attn_kinds)), rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? RoPEMode::kInline : rope_mode), rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), rope_ext_factors_(std::move(rope_ext_factors)), f_transpose_append_(std::move(f_transpose_append)), + f_transpose_append_mla_(std::move(f_transpose_append_mla)), f_compact_copy_(std::move(f_compact_copy)), f_attention_prefill_(std::move(f_attention_prefill)), f_attention_decode_(std::move(f_attention_decode)), @@ -1167,24 +1220,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)), f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)), f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)), + f_mla_prefill_(std::move(f_mla_prefill)), + f_mla_decode_(std::move(f_mla_decode)), + f_mla_prefill_ragged_normal_(std::move(f_mla_prefill_ragged_normal)), + f_mla_prefill_ragged_absorbed_(std::move(f_mla_prefill_ragged_absorbed)), f_merge_inplace_(std::move(f_merge_inplace)), f_split_rotary_(std::move(f_split_rotary)), + f_separate_rotary_(std::move(f_separate_rotary)), f_copy_single_page_(std::move(f_copy_single_page)), f_debug_get_kv_(std::move(f_debug_get_kv)), device_(device) { pages_.reserve(num_layers); if (enable_kv_transfer) { + // For now, KV transfer only supports MHA. + for (AttnKind attn_kind : attn_kinds_) { + CHECK(attn_kind == AttnKind::kMHA); + } CHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem") != nullptr) << "NVSHMEM is not enabled. 
Please make sure NVSHMEM is enabled when compiling TVM."; const PackedFunc* f_nvshmem_empty = runtime::Registry::Get("runtime.disco.nvshmem.empty"); ICHECK_NOTNULL(f_nvshmem_empty); nvshmem_pages_ = (*f_nvshmem_empty)( - ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, head_dim}), dtype, + ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}), dtype, device); for (int i = 0; i < num_layers; ++i) { pages_.push_back(nvshmem_pages_.CreateView( - {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_}, nvshmem_pages_->dtype, - i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ * + {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype, + i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ * nvshmem_pages_.DataType().bytes())); } @@ -1197,8 +1259,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr; } else { for (int i = 0; i < num_layers; ++i) { - pages_.push_back( - NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device)); + ShapeTuple kv_cache_shape = GetKVCacheShape( + attn_kinds_[layer_id_begin_offset_ + i], num_total_pages, reserved_num_seqs, + num_kv_heads, page_size, qk_head_dim, v_head_dim, qk_rope_head_dim); + pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device)); } } @@ -1274,13 +1338,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } temp_attn_q_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device); + NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); temp_attn_k_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype, device); + NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, dtype, device); temp_attn_v_device_ = - NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype, device); + NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, device); temp_attn_output_device_ = - NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device); + NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, dtype, device); temp_attn_scores_device_ = NDArray::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); merged_attn_scores_device_ = @@ -2019,8 +2083,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK(qkv_data.DataType() == pages.DataType()); CHECK(o_data.DataType() == pages.DataType()); - // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim) - // o_data: (num_total_length, num_qo_heads, head_dim) + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA); + // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) + // o_data: (num_total_length, num_qo_heads, qk_head_dim) CHECK_EQ(qkv_data->ndim, 3); CHECK_EQ(o_data->ndim, 3); @@ -2033,7 +2098,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - CHECK_EQ(qkv_data->shape[2], head_dim_); + CHECK_EQ(qkv_data->shape[2], qk_head_dim_); int64_t total_seq_length = 0; for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; @@ -2044,11 +2109,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // The auxiliary data structure on device must have been synchronized. 
ICHECK(!dirty_aux_data_device_); - NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, head_dim_}, + NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, qkv_data->dtype); - NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, + NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, qkv_data->dtype); - NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_}, + NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, qk_head_dim_}, qkv_data->dtype); NDArray qkv_data_view = qkv_data; @@ -2057,7 +2122,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { qkv_data_view = qkv_data.CreateView( {total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype); o_data_view = - o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype); + o_data.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, qkv_data->dtype); } // Part 2. Split fused qkv and apply rotary embedding to q/k data. if (transfer_kv_) { @@ -2105,6 +2170,28 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + Optional mask, NDArray o_data, + double attn_score_scaling_factor) final { + // Todo(ruihang): implement it + } + + void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray compressed_kv_data, NDArray k_pe_data, + NDArray o_data, double attn_score_scaling_factor) { + // Todo(ruihang): implement it + } + + void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data, + double attn_score_scaling_factor) { + // Todo(ruihang): implement it + } + + void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data, + double attn_score_scaling_factor) { + // Todo(ruihang): implement it + } + void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) final { CHECK_EQ(seq_ids.size(), leaf_indices.size()) << "The given seq_ids and leaf_indices have different size."; @@ -2216,9 +2303,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; - // k/v_data: (num_layers, seq_length, num_kv_heads, head_dim) + // k/v_data: (num_layers, seq_length, num_kv_heads, qk_head_dim) static constexpr const char* error_msg = - "DebugGetKV expects the k_data in layout (num_layers, seq_length, num_kv_heads, head_dim)."; + "DebugGetKV expects the k_data in layout (num_layers, seq_length, num_kv_heads, " + "qk_head_dim)."; std::vector vec_kv_data = {&k_data, &v_data}; for (const NDArray* data_ptr : vec_kv_data) { CHECK_EQ((*data_ptr)->ndim, 4) << error_msg; @@ -2228,7 +2316,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << error_msg << " The sequence length mismatches."; CHECK_EQ((*data_ptr)->shape[2], num_kv_heads_) << error_msg << " The number of heads mismatches."; - CHECK_EQ((*data_ptr)->shape[3], head_dim_) + CHECK_EQ((*data_ptr)->shape[3], qk_head_dim_) << error_msg << " The number of head features mismatches."; } @@ -2250,6 +2338,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map.data() + start_pos, 
(end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { + CHECK(attn_kinds_[layer_id] == AttnKind::kMHA) << "Only MHA is supported for DebugGetKV"; f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, v_data, layer_id); } } @@ -2649,7 +2738,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { temp_float_attn_workspace_, temp_int_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_, - num_kv_heads_, head_dim_, copy_stream_); + num_kv_heads_, qk_head_dim_, copy_stream_); } } for (int d = 0; d < num_depths_; ++d) { @@ -2661,15 +2750,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_attention_decode_begin_forward_.value()( d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(), - last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_, - page_size_, + last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, + qk_head_dim_, page_size_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_); } else { f_attention_prefill_begin_forward_.value()( /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(), static_cast(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_, - num_kv_heads_, head_dim_, page_size_, copy_stream_); + num_kv_heads_, qk_head_dim_, page_size_, copy_stream_); } } } @@ -2893,7 +2982,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // 16. Create view for temporary arrays for attention computation. temp_attn_output_view_ = temp_attn_output_device_.CreateView( - {total_append_length, num_qo_heads_, head_dim_}, temp_attn_output_device_->dtype); + {total_append_length, num_qo_heads_, qk_head_dim_}, temp_attn_output_device_->dtype); temp_attn_scores_view_ = temp_attn_scores_device_.CreateView( {total_append_length, num_qo_heads_}, temp_attn_scores_device_->dtype); merged_attn_scores_view_ = merged_attn_scores_device_.CreateView( @@ -2964,6 +3053,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") enable_kv_transfer = args[29]; } + std::vector attn_kinds(/*size=*/layer_indptr_tuple[num_groups], + /*value=*/AttnKind::kMHA); + CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -2975,13 +3067,18 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") // When sliding window is enabled, each sequence may use two more pages at most. num_total_pages += reserved_num_seqs * 2; } + // NOTE: We will remove this legacy construction after finishing the transition phase. + // Some `PackedFunc()` here are placeholders that will be filled. 
ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, - reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // - enable_kv_transfer, init->dtype, init->device, // - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, num_total_pages, + prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, + rotary_theta, + std::move(rope_ext_factors), // + enable_kv_transfer, init->dtype, init->device, // + std::move(f_transpose_append), PackedFunc(), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), std::move(f_attention_prefill_with_tree_mask_paged_kv), @@ -2989,7 +3086,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") std::move(f_attention_prefill_ragged_end_forward), std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), - std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), + PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), + std::move(f_split_rotary), PackedFunc(), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); }); @@ -3040,6 +3138,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") enable_kv_transfer = args[23]; } + std::vector attn_kinds(/*size=*/layer_indptr_tuple[num_groups], + /*value=*/AttnKind::kMHA); + CHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; @@ -3051,18 +3152,138 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") // When sliding window is enabled, each sequence may use two more pages at most. num_total_pages += reserved_num_seqs * 2; } + // NOTE: We will remove this legacy construction after finishing the transition phase. + // Some `PackedFunc()` here are placeholders that will be filled. 
ObjectPtr n = make_object( page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, head_dim, - reserved_num_seqs, num_total_pages, prefill_chunk_size, support_sliding_window, - RoPEMode(rope_mode), rotary_scale, rotary_theta, std::move(rope_ext_factors), // - enable_kv_transfer, init->dtype, init->device, // - std::move(f_transpose_append), std::move(f_compact_copy), std::move(f_attention_prefill), - std::move(f_attention_decode), std::move(f_attention_prefill_sliding_window), + head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, num_total_pages, + prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, + rotary_theta, + std::move(rope_ext_factors), // + enable_kv_transfer, init->dtype, init->device, // + std::move(f_transpose_append), PackedFunc(), std::move(f_compact_copy), + std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), std::move(f_attention_prefill_with_tree_mask), // std::move(f_attention_prefill_with_tree_mask_paged_kv), // NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, // - std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), + PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), std::move(f_merge_inplace), + std::move(f_split_rotary), PackedFunc(), std::move(f_copy_single_page), + std::move(f_debug_get_kv)); + *rv = AttentionKVCache(std::move(n)); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla") + .set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 39) << "Invalid number of KV cache constructor args."; + ShapeTuple cache_config = args[0]; + ShapeTuple layer_indptr_tuple = args[1]; + int num_groups = 1; + int group_id = 0; + if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { + // In the Disco worker thread + num_groups = disco_worker->num_groups; + group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); + } + CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); + int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; + int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; + int64_t num_qo_heads = args[2]; + int64_t num_kv_heads = args[3]; + int64_t qk_head_dim = args[4]; + int64_t v_head_dim = args[5]; + int64_t qk_rope_head_dim = args[6]; + IntTuple attn_kinds = args[7]; + int rope_mode = args[8]; + double rotary_scale = args[9]; + double rotary_theta = args[10]; + NDArray init = args[11]; + PackedFunc f_transpose_append = args[12]; + PackedFunc f_transpose_append_mla = args[13]; + PackedFunc f_attention_prefill = args[14]; + PackedFunc f_attention_decode = args[15]; + PackedFunc f_attention_prefill_sliding_window = args[16]; + PackedFunc f_attention_decode_sliding_window = args[17]; + PackedFunc f_attention_prefill_ragged = args[18]; + Optional f_attention_prefill_ragged_begin_forward = NullOpt; + Optional f_attention_prefill_ragged_end_forward = NullOpt; + Optional f_attention_prefill_begin_forward = NullOpt; + Optional f_attention_prefill_end_forward = NullOpt; + Optional f_attention_decode_begin_forward = NullOpt; + Optional f_attention_decode_end_forward = NullOpt; + PackedFunc f_mla_prefill = args[25]; + PackedFunc f_mla_decode = args[26]; + PackedFunc f_mla_prefill_ragged_normal = args[27]; + PackedFunc f_mla_prefill_ragged_absorbed = args[28]; + PackedFunc f_merge_inplace = args[29]; + 
PackedFunc f_split_rotary = args[30]; + PackedFunc f_separate_rotary = args[31]; + PackedFunc f_copy_single_page = args[32]; + Optional f_debug_get_kv = args[33]; + PackedFunc f_compact_copy = args[34]; + PackedFunc f_attention_prefill_with_tree_mask = args[35]; + PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[36]; + Optional rope_ext_factors = NullOpt; + bool enable_kv_transfer = false; + + if (args[37].IsObjectRef()) { + rope_ext_factors = args[37].AsObjectRef(); + } + enable_kv_transfer = args[38]; + + auto f_convert_optional_packed_func = [&args](int arg_idx) -> Optional { + if (args[arg_idx].IsObjectRef()) { + return args[arg_idx].AsObjectRef(); + } + return NullOpt; + }; + f_attention_prefill_ragged_begin_forward = f_convert_optional_packed_func(19); + f_attention_prefill_ragged_end_forward = f_convert_optional_packed_func(20); + f_attention_prefill_begin_forward = f_convert_optional_packed_func(21); + f_attention_prefill_end_forward = f_convert_optional_packed_func(22); + f_attention_decode_begin_forward = f_convert_optional_packed_func(23); + f_attention_decode_end_forward = f_convert_optional_packed_func(24); + + std::vector attn_kinds_vec; + attn_kinds_vec.reserve(attn_kinds.size()); + for (int64_t attn_kind : attn_kinds) { + attn_kinds_vec.push_back(static_cast(attn_kind)); + } + + CHECK_EQ(cache_config.size(), 5); + int64_t reserved_num_seqs = cache_config[0]; + int64_t total_token_capacity = cache_config[1]; + int64_t prefill_chunk_size = cache_config[2]; + int64_t page_size = cache_config[3]; + bool support_sliding_window = cache_config[4]; + int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; + if (support_sliding_window) { + // When sliding window is enabled, each sequence may use two more pages at most. + num_total_pages += reserved_num_seqs * 2; + } + // NOTE: We will remove this legacy construction after finishing the transition phase. + // Some `PackedFunc()` here are placeholders that will be filled. + ObjectPtr n = make_object( + page_size, num_layers, layer_id_begin_offset, num_qo_heads, num_kv_heads, qk_head_dim, + v_head_dim, qk_rope_head_dim, attn_kinds_vec, reserved_num_seqs, num_total_pages, + prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), rotary_scale, + rotary_theta, + std::move(rope_ext_factors), // + enable_kv_transfer, init->dtype, init->device, // + std::move(f_transpose_append), std::move(f_transpose_append_mla), + std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode), + std::move(f_attention_prefill_sliding_window), + std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), + std::move(f_attention_prefill_with_tree_mask), // + std::move(f_attention_prefill_with_tree_mask_paged_kv), // + std::move(f_attention_prefill_ragged_begin_forward), + std::move(f_attention_prefill_ragged_end_forward), + std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward), + std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward), + std::move(f_mla_prefill), std::move(f_mla_decode), std::move(f_mla_prefill_ragged_normal), + std::move(f_mla_prefill_ragged_absorbed), std::move(f_merge_inplace), + std::move(f_split_rotary), std::move(f_separate_rotary), std::move(f_copy_single_page), std::move(f_debug_get_kv)); *rv = AttentionKVCache(std::move(n)); });
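For reference, the following is a minimal sketch of how a caller could exercise the new `MLAAbsorbed` interface once its body is implemented, using the tensor layouts documented in the kv_state.h doc comments above. It assumes a DeepSeek-V2-style configuration (kv_lora_rank = 512, qk_rope_head_dim = 64, a single latent KV head); those sizes, the function name, and the file placement are illustrative assumptions and are not fixed by this patch.

#include <tvm/runtime/ndarray.h>

#include "kv_state.h"  // assuming this sketch lives next to src/runtime/relax_vm/kv_state.h

namespace tvm {
namespace runtime {
namespace relax_vm {

// Illustrative only: drive MLAAbsorbed with the layouts documented in kv_state.h.
// The sizes below are DeepSeek-V2-style assumptions, not values fixed by this patch.
void ExampleMLAAbsorbedCall(AttentionKVCache kv_cache, Device device) {
  int64_t total_length = 16;      // tokens appended in this forward step
  int64_t num_qo_heads = 128;     // query heads (illustrative)
  int64_t num_kv_heads = 1;       // the compressed latent KV is shared across query heads
  int64_t kv_lora_rank = 512;     // latent dimension (illustrative)
  int64_t qk_rope_head_dim = 64;  // RoPE dimension (illustrative)
  int64_t qk_head_dim = kv_lora_rank + qk_rope_head_dim;  // per the MLAAbsorbed doc comment
  int64_t v_head_dim = kv_lora_rank;  // per-head output dim in the absorbed formulation
  DLDataType dtype{kDLFloat, 16, 1};

  // q_data: (total_length, num_qo_heads, qk_head_dim)
  NDArray q = NDArray::Empty({total_length, num_qo_heads, qk_head_dim}, dtype, device);
  // compressed_kv_data: (total_length, num_kv_heads, kv_lora_rank)
  NDArray ckv = NDArray::Empty({total_length, num_kv_heads, kv_lora_rank}, dtype, device);
  // k_pe_data: (total_length, num_kv_heads, qk_rope_head_dim)
  NDArray kpe = NDArray::Empty({total_length, num_kv_heads, qk_rope_head_dim}, dtype, device);
  // o_data: (total_length, num_qo_heads, v_head_dim)
  NDArray o = NDArray::Empty({total_length, num_qo_heads, v_head_dim}, dtype, device);

  // BeginForward/EndForward sequence bookkeeping is elided here.
  kv_cache->MLAAbsorbed(/*layer_id=*/0, q, ckv, kpe, o,
                        /*attn_score_scaling_factor=*/1.0);
}

}  // namespace relax_vm
}  // namespace runtime
}  // namespace tvm

The num_kv_heads = 1 choice reflects how DeepSeek-style models share one compressed latent KV across all query heads; the interface itself does not require it.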