[Serving] PagedKVCache tree-attention integration (#2487)
This PR integrates the recently added tree-attention support in PagedKVCache
into speculative decoding in MLC. Right now only chains are supported.
Tree-based speculative decoding is on the project roadmap, and we plan to
support it in the near future.
MasterJH5574 authored Jun 4, 2024
1 parent 90170e6 commit c0c33a5
Showing 9 changed files with 151 additions and 26 deletions.
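
Since only the chain case of tree attention is exercised for now, here is a minimal standalone sketch of the parent-pointer layout that the verify actions below construct. The helper name BuildChainParentPtr is hypothetical; only the inner loop mirrors the code added in batch_verify.cc and eagle_batch_verify.cc.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: build the flattened parent-pointer array for a batch of
// chain-shaped token trees. The inner loop mirrors the loops added in
// batch_verify.cc and eagle_batch_verify.cc below.
std::vector<int64_t> BuildChainParentPtr(const std::vector<int>& verify_lengths) {
  std::vector<int64_t> token_tree_parent_ptr;
  for (int length : verify_lengths) {
    for (int pos = 0; pos < length; ++pos) {
      // pos == 0 is the root of this sequence's chain (parent -1); every other
      // token's parent is the token right before it.
      token_tree_parent_ptr.push_back(pos - 1);
    }
  }
  return token_tree_parent_ptr;
}

int main() {
  // Two requests verifying 3 and 2 draft tokens respectively produce two
  // independent chains in one flattened array.
  std::vector<int64_t> parents = BuildChainParentPtr({3, 2});
  assert((parents == std::vector<int64_t>{-1, 0, 1, -1, 0}));
  return 0;
}
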
27 changes: 22 additions & 5 deletions cpp/serve/engine_actions/batch_verify.cc
@@ -101,9 +101,19 @@ class BatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(total_verify_length);
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), total_verify_length);

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
NDArray logits =
models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, verify_lengths);
NDArray logits = models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids,
verify_lengths, token_tree_parent_ptr);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
ICHECK_EQ(logits->ndim, 3);
ICHECK_EQ(logits->shape[0], 1);
@@ -138,7 +148,11 @@ class BatchVerifyActionObj : public EngineActionObj {
// by the draft model but not added into the draft model's KV cache.
// In this case, an additional batch decode step is needed for these requests.
std::vector<int64_t> fully_accepted_rsentries;
std::vector<int64_t> verify_model_seq_internal_ids;
std::vector<int64_t> accepted_token_tree_leaf_nodes;
fully_accepted_rsentries.reserve(num_rsentries);
verify_model_seq_internal_ids.reserve(num_rsentries);
accepted_token_tree_leaf_nodes.reserve(num_rsentries);

for (int i = 0; i < num_rsentries; ++i) {
const std::vector<SampleResult>& sample_results = sample_results_arr[i];
@@ -154,12 +168,13 @@ class BatchVerifyActionObj : public EngineActionObj {
accept_length);
int rollback_length =
std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);
// rollback kv cache
// Commit accepted tokens to the "verify_model", rollback kv cache
// in the "draft_model".
// NOTE: when number of small models is more than 1 (in the future),
// it is possible to re-compute prefill for the small models.
verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);
accepted_token_tree_leaf_nodes.push_back(accept_length - 1);
if (rollback_length > 0) {
models_[verify_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length);
// The last accepted token is not yet added into the draft model.
// Therefore, the rollback length for the draft model is one less.
models_[draft_model_id_]->PopNFromKVCache(
@@ -168,6 +183,8 @@
fully_accepted_rsentries.push_back(i);
}
}
models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(
verify_model_seq_internal_ids, accepted_token_tree_leaf_nodes);

if (!fully_accepted_rsentries.empty()) {
// - Run a step of batch decode for requests whose drafts are fully accepted.
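
A minimal sketch of the acceptance arithmetic above (the struct and helper are hypothetical; the formulas mirror batch_verify.cc): the last accepted node of each chain is accept_length - 1, the rejected suffix has length verify_length - accept_length, and the draft model rolls back one token fewer because the last accepted token was never added to its KV cache.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical summary of the per-request arithmetic in the loop above.
struct VerifyOutcome {
  int64_t accepted_leaf_node;  // last accepted node of the chain (accept_length - 1)
  int rollback_length;         // number of rejected draft tokens in this chain
  int draft_rollback_length;   // tokens popped from the draft model's KV cache
};

VerifyOutcome SummarizeChainVerify(int verify_length, int accept_length) {
  VerifyOutcome out;
  // Chain nodes are numbered 0 .. verify_length - 1, so the committed leaf is
  // the (accept_length - 1)-th node.
  out.accepted_leaf_node = accept_length - 1;
  out.rollback_length = std::max(verify_length - accept_length, 0);
  // The last accepted token was never added to the draft model's KV cache, so
  // the draft model rolls back one token fewer (and when nothing is rejected,
  // the request instead joins the extra batch-decode step).
  out.draft_rollback_length = out.rollback_length > 0 ? out.rollback_length - 1 : 0;
  return out;
}

int main() {
  VerifyOutcome out = SummarizeChainVerify(/*verify_length=*/5, /*accept_length=*/3);
  std::cout << out.accepted_leaf_node << " " << out.rollback_length << " "
            << out.draft_rollback_length << "\n";  // prints: 2 2 1
  return 0;
}
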
25 changes: 21 additions & 4 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
@@ -111,9 +111,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(cum_verify_lengths.back());
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), cum_verify_lengths.back());

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
embeddings, request_internal_ids, verify_lengths);
embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
ICHECK_EQ(logits->ndim, 2);
@@ -141,7 +151,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// by the draft model but not added into the draft model's KV cache.
// In this case, an additional batch decode step is needed for these requests.
std::vector<int64_t> fully_accepted_rsentries;
std::vector<int64_t> verify_model_seq_internal_ids;
std::vector<int64_t> accepted_token_tree_leaf_nodes;
fully_accepted_rsentries.reserve(num_rsentries);
verify_model_seq_internal_ids.reserve(num_rsentries);
accepted_token_tree_leaf_nodes.reserve(num_rsentries);

std::vector<int> last_accepted_hidden_positions;
last_accepted_hidden_positions.reserve(num_rsentries);
@@ -163,12 +177,13 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
int rollback_length =
std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);

// rollback kv cache
// Commit accepted tokens to the "verify_model", rollback kv cache
// in the "draft_model".
// NOTE: when number of small models is more than 1 (in the future),
// it is possible to re-compute prefill for the small models.
verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);
accepted_token_tree_leaf_nodes.push_back(accept_length - 1);
if (rollback_length > 0) {
models_[verify_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length);
// Draft model rollback minus one because verify uses one more token.
models_[draft_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1);
@@ -181,6 +196,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// - Slice and save hidden_states_for_sample
last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1);
}
models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(
verify_model_seq_internal_ids, accepted_token_tree_leaf_nodes);
if (!fully_accepted_rsentries.empty() &&
engine_config_->speculative_mode == SpeculativeMode::kEagle) {
// - Run a step of batch decode for requests whose drafts are fully accepted.
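
In the EAGLE path, the action also records where each request's last accepted hidden state sits in the flattened batch so it can be sliced for the next draft step. A small sketch of that index computation (hypothetical helper mirroring the last_accepted_hidden_positions bookkeeping above):

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper: request i occupies rows [cum_verify_lengths[i],
// cum_verify_lengths[i + 1]) of the flattened hidden states, and its last
// accepted token sits at offset accept_lengths[i] - 1 inside that slice.
std::vector<int> LastAcceptedHiddenPositions(const std::vector<int>& cum_verify_lengths,
                                             const std::vector<int>& accept_lengths) {
  std::vector<int> positions;
  positions.reserve(accept_lengths.size());
  for (size_t i = 0; i < accept_lengths.size(); ++i) {
    positions.push_back(cum_verify_lengths[i] + accept_lengths[i] - 1);
  }
  return positions;
}

int main() {
  // Two requests with verify lengths 3 and 2 (cumulative {0, 3, 5}) accepting
  // 2 tokens each: their last accepted hidden states are rows 1 and 4.
  for (int pos : LastAcceptedHiddenPositions({0, 3, 5}, {2, 2})) std::cout << pos << " ";
  std::cout << "\n";
  return 0;
}
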
2 changes: 2 additions & 0 deletions cpp/serve/function_table.cc
@@ -242,6 +242,8 @@ void FunctionTable::_InitFunctions() {
this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward");
this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward");
this->kv_cache_popn_func_ = get_global_func("vm.builtin.kv_state_popn");
this->kv_cache_commit_accepted_token_tree_nodes_func_ =
get_global_func("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes");
this->kv_cache_get_num_available_pages_func_ =
*tvm::runtime::Registry::Get("vm.builtin.attention_kv_cache_get_num_available_pages");
this->kv_cache_get_total_sequence_length_func_ =
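
As a rough sketch of how the newly registered builtin can be looked up and invoked from C++: the argument order (kv_cache, seq_ids, accepted_leaf_indices) follows the model.cc change later in this commit; the wrapper function and the concrete sequence ids are assumptions.

#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::IntTuple;
using tvm::runtime::ObjectRef;
using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

// Sketch only: commit the accepted tree nodes for two sequences, keeping chain
// nodes 0..2 of the first sequence and 0..1 of the second.
void CommitAcceptedNodesExample(ObjectRef kv_cache) {
  const PackedFunc* commit =
      Registry::Get("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes");
  ICHECK(commit != nullptr) << "builtin not registered";
  (*commit)(kv_cache, IntTuple({0, 1}), IntTuple({2, 1}));
}
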
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
@@ -109,6 +109,7 @@ struct FunctionTable {
PackedFunc kv_cache_begin_forward_func_;
PackedFunc kv_cache_end_forward_func_;
PackedFunc kv_cache_popn_func_;
PackedFunc kv_cache_commit_accepted_token_tree_nodes_func_;
PackedFunc kv_cache_get_num_available_pages_func_;
PackedFunc kv_cache_get_total_sequence_length_func_;
PackedFunc gpu_multinomial_from_uniform_func_;
23 changes: 19 additions & 4 deletions cpp/serve/model.cc
@@ -450,14 +450,16 @@ class ModelImpl : public ModelObj {
}

NDArray BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) final {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
}
CHECK_EQ(total_length, token_tree_parent_ptr.size());

NVTXScopedRange nvtx_scope("BatchVerify num_tokens=" + std::to_string(total_length));

@@ -471,7 +473,9 @@
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
@@ -512,14 +516,16 @@

ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings,
const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) final {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
}
CHECK_EQ(total_length, token_tree_parent_ptr.size());
NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden num_tokens=" +
std::to_string(total_length));

@@ -548,7 +554,9 @@
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
token_tree_parent_ptr_tuple);

// args: embeddings, logit_pos, kv_cache, params
ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_);
@@ -629,6 +637,13 @@
ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens);
}

void CommitAcceptedTokenTreeNodesToKVCache(
const std::vector<int64_t>& seq_ids,
const std::vector<int64_t>& accepted_leaf_indices) final {
ft_.kv_cache_commit_accepted_token_tree_nodes_func_(kv_cache_, IntTuple(seq_ids),
IntTuple(accepted_leaf_indices));
}

void EnableSlidingWindowForSeq(int64_t seq_id) final {
if (this->kind == KVStateKind::kNone) {
return;
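
Putting the model-level pieces together, a condensed mock of the call order an engine action follows per verification round. This is not code from the repository: only the two method names and their argument conventions come from this diff, and the embeddings/logits plumbing is omitted.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the serve Model interface.
struct MockVerifyModel {
  void BatchVerify(const std::vector<int64_t>& seq_ids, const std::vector<int>& lengths,
                   const std::vector<int64_t>& token_tree_parent_ptr) {
    std::cout << "BatchVerify: " << seq_ids.size() << " sequences, "
              << token_tree_parent_ptr.size() << " tree nodes\n";
    (void)lengths;
  }
  void CommitAcceptedTokenTreeNodesToKVCache(const std::vector<int64_t>& seq_ids,
                                             const std::vector<int64_t>& accepted_leaf_indices) {
    std::cout << "Commit leaves for " << seq_ids.size() << " sequences ("
              << accepted_leaf_indices.size() << " leaf indices)\n";
  }
};

int main() {
  MockVerifyModel model;
  std::vector<int64_t> seq_ids = {0, 1};
  std::vector<int> verify_lengths = {4, 3};

  // 1. Chain-shaped parent pointers, one independent chain per sequence.
  std::vector<int64_t> parents;
  for (int length : verify_lengths)
    for (int pos = 0; pos < length; ++pos) parents.push_back(pos - 1);

  // 2. Batched verification forward pass with the tree structure attached.
  model.BatchVerify(seq_ids, verify_lengths, parents);

  // 3. Sampling/acceptance (omitted): suppose request 0 accepts 2 tokens and
  //    request 1 accepts its whole chain; leaf index = accept_length - 1.
  std::vector<int64_t> accepted_leaf_indices = {1, 2};

  // 4. Commit the accepted prefix of each chain; the runtime builtin drops the
  //    rejected nodes from the paged KV cache.
  model.CommitAcceptedTokenTreeNodesToKVCache(seq_ids, accepted_leaf_indices);
  return 0;
}
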
20 changes: 18 additions & 2 deletions cpp/serve/model.h
@@ -190,28 +190,36 @@ class ModelObj : public Object {
* \param embeddings The embedding of the input to be verified.
* \param seq_id The id of the sequence in the KV cache.
* \param lengths The length of each sequence to verify.
* \param token_tree_parent_ptr The parent pointers of the token tree.
* Its size is the sum of "lengths". It contains a batch of independent trees,
* one for each sequence. Parent being "-1" means the node is a root.
* \return The logits for the draft token for each sequence in the batch.
* \note The function runs for **every** sequence in the batch.
* That is to say, it does not accept "running a verify step for a subset
* of the full batch".
*/
virtual NDArray BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) = 0;
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) = 0;

/*!
* \brief Batch verify function. Input hidden_states are computed from
* input embeddings and previous hidden_states, output last hidden_states.
* \param hidden_states The hidden_states of the input to be verified.
* \param seq_id The id of the sequence in the KV cache.
* \param lengths The length of each sequence to verify.
* \param token_tree_parent_ptr The parent pointers of the token tree.
* Its size is the sum of "lengths". It contains a batch of independent trees,
* one for each sequence. Parent being "-1" means the node is a root.
* \return The hidden_states for the draft token for each sequence in the batch.
* \note The function runs for **every** sequence in the batch.
* That is to say, it does not accept "running a verify step for a subset
* of the full batch".
*/
virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states,
const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) = 0;
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) = 0;

/*********************** KV Cache Management ***********************/

@@ -242,6 +250,14 @@
/*! \brief Pop out N pages from KV cache. */
virtual void PopNFromKVCache(int64_t seq_id, int num_tokens) = 0;

/*!
* \brief Commit the accepted token tree nodes to KV cache.
* The unaccepted token tree nodes will be removed from the KV cache.
* This is usually used in the verification stage of speculative decoding.
*/
virtual void CommitAcceptedTokenTreeNodesToKVCache(
const std::vector<int64_t>& seq_ids, const std::vector<int64_t>& accepted_leaf_indices) = 0;

/*!
* \brief Enabling sliding window for the given sequence.
* It is a no-op if the model does not support sliding window.
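
The doc comments above define token_tree_parent_ptr for general trees even though the engine currently builds only chains. A small illustrative program (not the kernel implementation) showing a branching tree under that convention and the ancestor relation a tree-attention mask encodes:

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if `node` equals `ancestor` or descends from it inside one token
// tree described by parent pointers (-1 marks a root). Under tree attention, a
// draft token attends to the committed prefix plus exactly its own ancestors.
bool IsSelfOrDescendant(const std::vector<int64_t>& parent, int64_t ancestor, int64_t node) {
  while (node != -1) {
    if (node == ancestor) return true;
    node = parent[node];
  }
  return false;
}

int main() {
  // Token tree for one sequence, with a branch at the root:
  //   node 0 (root)
  //   ├── node 1 ── node 2
  //   └── node 3
  std::vector<int64_t> parent = {-1, 0, 1, 0};

  // Within-tree attention mask: row q marks which tree nodes k node q may
  // attend to. Node 2 sees {0, 1, 2}; node 3 sees only {0, 3}.
  for (int64_t q = 0; q < 4; ++q) {
    for (int64_t k = 0; k < 4; ++k) {
      std::cout << (IsSelfOrDescendant(parent, /*ancestor=*/k, /*node=*/q) ? 1 : 0) << " ";
    }
    std::cout << "\n";
  }
  return 0;
}
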
56 changes: 56 additions & 0 deletions python/mlc_llm/nn/kv_cache.py
@@ -13,6 +13,7 @@
from tvm.target import Target

from mlc_llm.op.position_embedding import llama_rope_with_position_map, rope_freq
from mlc_llm.op.tree_attn import tree_attn

from ..support.max_thread_check import (
check_thread_limits,
@@ -246,6 +247,8 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"),
# fmt: on
# pylint: enable=line-too-long
]
@@ -350,6 +353,8 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"),
# fmt: on
# pylint: enable=line-too-long
]
@@ -1570,3 +1575,54 @@ def copy_single_page(
pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

return copy_single_page


def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
tx = get_max_num_threads_per_block(target)

@T.prim_func
def compact_kv_copy(
var_pages: T.handle,
var_copy_length_indptr: T.handle,
var_copy_src_dst_pos: T.handle,
batch_size: T.int32,
):
T.func_attr({"tir.is_scheduled": 1})
num_pages = T.int32()
total_copy_length = T.int32()
copy_length_indptr_elem_offset = T.int32()
copy_src_dst_pos_elem_offset = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
copy_length_indptr = T.match_buffer(
var_copy_length_indptr,
(batch_size + 1,),
"int32",
elem_offset=copy_length_indptr_elem_offset,
)
copy_src_dst_pos = T.match_buffer(
var_copy_src_dst_pos,
(2, total_copy_length),
"int32",
elem_offset=copy_src_dst_pos_elem_offset,
)

with T.block("root"):
for bhd_o in T.thread_binding(
(batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
):
for bhd_i in T.thread_binding(tx, thread="threadIdx.x"):
b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
src_pos // 16, 0, h, src_pos % 16, d
]
pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
src_pos // 16, 1, h, src_pos % 16, d
]

return compact_kv_copy
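
For reference, an illustrative C++ restatement of the index arithmetic in compact_kv_copy, assuming the 16-entry page size hard-coded in the buffer shape:

#include <cassert>
#include <tuple>
#include <utility>

// The flattened thread index decomposes into (batch, head, dim), and a token
// position maps to (page, slot) under the 16-entry page size of the
// (num_pages, 2, num_heads, 16, head_dim) pages buffer above.
std::tuple<int, int, int> DecomposeThreadIndex(int idx, int num_heads, int head_dim) {
  int b = idx / (num_heads * head_dim);
  int h = idx / head_dim % num_heads;
  int d = idx % head_dim;
  return {b, h, d};
}

std::pair<int, int> PositionToPageSlot(int pos) { return {pos / 16, pos % 16}; }

int main() {
  // With 8 KV heads of head_dim 128, flattened index 1500 maps to
  // batch element 1, head 3, dim 92.
  auto [b, h, d] = DecomposeThreadIndex(1500, /*num_heads=*/8, /*head_dim=*/128);
  assert(b == 1 && h == 3 && d == 92);

  // Token position 37 lives in page 2, slot 5.
  auto [page, slot] = PositionToPageSlot(37);
  assert(page == 2 && slot == 5);
  return 0;
}
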