Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serving] PagedKVCache tree-attention integration #2487

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions cpp/serve/engine_actions/batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,19 @@ class BatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(total_verify_length);
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), total_verify_length);

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
NDArray logits =
models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids, verify_lengths);
NDArray logits = models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids,
verify_lengths, token_tree_parent_ptr);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
ICHECK_EQ(logits->ndim, 3);
ICHECK_EQ(logits->shape[0], 1);
Expand Down Expand Up @@ -138,7 +148,11 @@ class BatchVerifyActionObj : public EngineActionObj {
// by the draft model but not added into the draft model's KV cache.
// In this case, an additional batch decode step is needed for these requests.
std::vector<int64_t> fully_accepted_rsentries;
std::vector<int64_t> verify_model_seq_internal_ids;
std::vector<int64_t> accepted_token_tree_leaf_nodes;
fully_accepted_rsentries.reserve(num_rsentries);
verify_model_seq_internal_ids.reserve(num_rsentries);
accepted_token_tree_leaf_nodes.reserve(num_rsentries);

for (int i = 0; i < num_rsentries; ++i) {
const std::vector<SampleResult>& sample_results = sample_results_arr[i];
Expand All @@ -154,12 +168,13 @@ class BatchVerifyActionObj : public EngineActionObj {
accept_length);
int rollback_length =
std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);
// rollback kv cache
// Commit accepted tokens to the "verify_model", rollback kv cache
// in the "draft_model".
// NOTE: when number of small models is more than 1 (in the future),
// it is possible to re-compute prefill for the small models.
verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);
accepted_token_tree_leaf_nodes.push_back(accept_length - 1);
if (rollback_length > 0) {
models_[verify_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length);
// The last accepted token is not yet added into the draft model.
// Therefore, the rollback length for the draft model is one less.
models_[draft_model_id_]->PopNFromKVCache(
Expand All @@ -168,6 +183,8 @@ class BatchVerifyActionObj : public EngineActionObj {
fully_accepted_rsentries.push_back(i);
}
}
models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(
verify_model_seq_internal_ids, accepted_token_tree_leaf_nodes);

if (!fully_accepted_rsentries.empty()) {
// - Run a step of batch decode for requests whose drafts are fully accepted.
Expand Down
25 changes: 21 additions & 4 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,19 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(cum_verify_lengths.back());
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), cum_verify_lengths.back());

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
embeddings, request_internal_ids, verify_lengths);
embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states);
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
ICHECK_EQ(logits->ndim, 2);
Expand Down Expand Up @@ -141,7 +151,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// by the draft model but not added into the draft model's KV cache.
// In this case, an additional batch decode step is needed for these requests.
std::vector<int64_t> fully_accepted_rsentries;
std::vector<int64_t> verify_model_seq_internal_ids;
std::vector<int64_t> accepted_token_tree_leaf_nodes;
fully_accepted_rsentries.reserve(num_rsentries);
verify_model_seq_internal_ids.reserve(num_rsentries);
accepted_token_tree_leaf_nodes.reserve(num_rsentries);

std::vector<int> last_accepted_hidden_positions;
last_accepted_hidden_positions.reserve(num_rsentries);
Expand All @@ -163,12 +177,13 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
int rollback_length =
std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);

// rollback kv cache
// Commit accepted tokens to the "verify_model", rollback kv cache
// in the "draft_model".
// NOTE: when number of small models is more than 1 (in the future),
// it is possible to re-compute prefill for the small models.
verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);
accepted_token_tree_leaf_nodes.push_back(accept_length - 1);
if (rollback_length > 0) {
models_[verify_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[verify_model_id_]->internal_id, rollback_length);
// Draft model rollback minus one because verify uses one more token.
models_[draft_model_id_]->PopNFromKVCache(
rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1);
Expand All @@ -181,6 +196,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// - Slice and save hidden_states_for_sample
last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1);
}
models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(
verify_model_seq_internal_ids, accepted_token_tree_leaf_nodes);
if (!fully_accepted_rsentries.empty() &&
engine_config_->speculative_mode == SpeculativeMode::kEagle) {
// - Run a step of batch decode for requests whose drafts are fully accepted.
Expand Down
2 changes: 2 additions & 0 deletions cpp/serve/function_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ void FunctionTable::_InitFunctions() {
this->kv_cache_begin_forward_func_ = get_global_func("vm.builtin.kv_state_begin_forward");
this->kv_cache_end_forward_func_ = get_global_func("vm.builtin.kv_state_end_forward");
this->kv_cache_popn_func_ = get_global_func("vm.builtin.kv_state_popn");
this->kv_cache_commit_accepted_token_tree_nodes_func_ =
get_global_func("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes");
this->kv_cache_get_num_available_pages_func_ =
*tvm::runtime::Registry::Get("vm.builtin.attention_kv_cache_get_num_available_pages");
this->kv_cache_get_total_sequence_length_func_ =
Expand Down
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ struct FunctionTable {
PackedFunc kv_cache_begin_forward_func_;
PackedFunc kv_cache_end_forward_func_;
PackedFunc kv_cache_popn_func_;
PackedFunc kv_cache_commit_accepted_token_tree_nodes_func_;
PackedFunc kv_cache_get_num_available_pages_func_;
PackedFunc kv_cache_get_total_sequence_length_func_;
PackedFunc gpu_multinomial_from_uniform_func_;
Expand Down
23 changes: 19 additions & 4 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,16 @@ class ModelImpl : public ModelObj {
}

NDArray BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) final {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
}
CHECK_EQ(total_length, token_tree_parent_ptr.size());

NVTXScopedRange nvtx_scope("BatchVerify num_tokens=" + std::to_string(total_length));

Expand All @@ -471,7 +473,9 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
Expand Down Expand Up @@ -512,14 +516,16 @@ class ModelImpl : public ModelObj {

ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings,
const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) final {
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
}
CHECK_EQ(total_length, token_tree_parent_ptr.size());
NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden num_tokens=" +
std::to_string(total_length));

Expand Down Expand Up @@ -548,7 +554,9 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
token_tree_parent_ptr_tuple);

// args: embeddings, logit_pos, kv_cache, params
ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_);
Expand Down Expand Up @@ -629,6 +637,13 @@ class ModelImpl : public ModelObj {
ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens);
}

void CommitAcceptedTokenTreeNodesToKVCache(
const std::vector<int64_t>& seq_ids,
const std::vector<int64_t>& accepted_leaf_indices) final {
ft_.kv_cache_commit_accepted_token_tree_nodes_func_(kv_cache_, IntTuple(seq_ids),
IntTuple(accepted_leaf_indices));
}

void EnableSlidingWindowForSeq(int64_t seq_id) final {
if (this->kind == KVStateKind::kNone) {
return;
Expand Down
20 changes: 18 additions & 2 deletions cpp/serve/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,28 +190,36 @@ class ModelObj : public Object {
* \param embeddings The embedding of the input to be verified.
* \param seq_id The id of the sequence in the KV cache.
* \param lengths The length of each sequence to verify.
* \param token_tree_parent_ptr The parent pointers of the token tree.
* Its size is the sum of "lengths". It contains a batch of independent trees,
* one for each sequence. Parent being "-1" means the node is a root.
* \return The logits for the draft token for each sequence in the batch.
* \note The function runs for **every** sequence in the batch.
* That is to say, it does not accept "running a verify step for a subset
* of the full batch".
*/
virtual NDArray BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) = 0;
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) = 0;

/*!
* \brief Batch verify function. Input hidden_states are computed from
* input embeddings and previous hidden_states, output last hidden_states.
* \param hidden_states The hidden_states of the input to be verified.
* \param seq_id The id of the sequence in the KV cache.
* \param lengths The length of each sequence to verify.
* \param token_tree_parent_ptr The parent pointers of the token tree.
* Its size is the sum of "lengths". It contains a batch of independent trees,
* one for each sequence. Parent being "-1" means the node is a root.
* \return The hidden_states for the draft token for each sequence in the batch.
* \note The function runs for **every** sequence in the batch.
* That is to say, it does not accept "running a verify step for a subset
* of the full batch".
*/
virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states,
const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) = 0;
const std::vector<int>& lengths,
const std::vector<int64_t>& token_tree_parent_ptr) = 0;

/*********************** KV Cache Management ***********************/

Expand Down Expand Up @@ -242,6 +250,14 @@ class ModelObj : public Object {
/*! \brief Pop out N pages from KV cache. */
virtual void PopNFromKVCache(int64_t seq_id, int num_tokens) = 0;

/*!
* \brief Commit the accepted token tree nodes to KV cache.
* The unaccepted token tree nodes will be removed from the KV cache.
* This is usually used in the verification stage of speculative decoding.
*/
virtual void CommitAcceptedTokenTreeNodesToKVCache(
const std::vector<int64_t>& seq_ids, const std::vector<int64_t>& accepted_leaf_indices) = 0;

/*!
* \brief Enabling sliding window for the given sequence.
* It is a no-op if the model does not support sliding window.
Expand Down
56 changes: 56 additions & 0 deletions python/mlc_llm/nn/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tvm.target import Target

from mlc_llm.op.position_embedding import llama_rope_with_position_map, rope_freq
from mlc_llm.op.tree_attn import tree_attn

from ..support.max_thread_check import (
check_thread_limits,
Expand Down Expand Up @@ -246,6 +247,8 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"),
# fmt: on
# pylint: enable=line-too-long
]
Expand Down Expand Up @@ -350,6 +353,8 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"),
# fmt: on
# pylint: enable=line-too-long
]
Expand Down Expand Up @@ -1570,3 +1575,54 @@ def copy_single_page(
pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

return copy_single_page


def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
    # Build a TIR kernel that copies KV-cache entries from scattered source
    # positions to destination positions inside the paged "pages" buffer.
    # One GPU thread is bound per (batch, head, feature-dim) triple; each
    # thread serially walks the copy positions belonging to its batch element.
    # NOTE(review): the page size appears to be hard-coded to 16 (the `// 16`
    # and `% 16` indexing and the buffer shape) — confirm against the cache's
    # page-size configuration. Axis 1 of `pages` (size 2) presumably holds the
    # K and V planes; verify against the cache layout.
    tx = get_max_num_threads_per_block(target)

    @T.prim_func
    def compact_kv_copy(
        var_pages: T.handle,
        var_copy_length_indptr: T.handle,
        var_copy_src_dst_pos: T.handle,
        batch_size: T.int32,
    ):
        T.func_attr({"tir.is_scheduled": 1})
        num_pages = T.int32()
        total_copy_length = T.int32()
        copy_length_indptr_elem_offset = T.int32()
        copy_src_dst_pos_elem_offset = T.int32()
        # Paged KV storage: (page, K/V plane, head, slot-in-page, feature dim).
        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
        # CSR-style indptr: batch element b owns copy entries
        # [copy_length_indptr[b], copy_length_indptr[b + 1]).
        copy_length_indptr = T.match_buffer(
            var_copy_length_indptr,
            (batch_size + 1,),
            "int32",
            elem_offset=copy_length_indptr_elem_offset,
        )
        # Row 0: flattened source positions; row 1: destination positions.
        copy_src_dst_pos = T.match_buffer(
            var_copy_src_dst_pos,
            (2, total_copy_length),
            "int32",
            elem_offset=copy_src_dst_pos_elem_offset,
        )

        with T.block("root"):
            # Flat 1-D launch over batch_size * num_heads * head_dim elements,
            # split into blocks of `tx` threads (ceil-divided).
            for bhd_o in T.thread_binding(
                (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
            ):
                for bhd_i in T.thread_binding(tx, thread="threadIdx.x"):
                    # Decompose the flat thread index into (batch, head, dim).
                    b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
                    h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
                    d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
                    # Guard the ragged tail of the last thread block.
                    if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
                        for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            # Copy both planes (indices 0 and 1) for this
                            # head/dim from the source slot to the dest slot;
                            # positions are global and mapped to (page, slot)
                            # by div/mod with the page size of 16.
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
                                src_pos // 16, 0, h, src_pos % 16, d
                            ]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
                                src_pos // 16, 1, h, src_pos % 16, d
                            ]

    return compact_kv_copy
Loading
Loading