Skip to content

Commit

Permalink
[Runtime] Support PagedKVCache with tree attention (#17049)
Browse files Browse the repository at this point in the history
* [Runtime] Support PagedKVCache with tree attention

This PR introduces tree attention to PagedKVCache. With this
feature, the KV cache is ready for tree attention use cases such as
speculative decoding trees.

This PR also adds tree attention tests to verify correctness.

The changes in this PR to KVState interface are backward compatible.

* Update kv_state.cc

* Update kv_state.cc

---------

Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
  • Loading branch information
MasterJH5574 and tqchen authored Jun 1, 2024
1 parent 515c079 commit 31f4721
Show file tree
Hide file tree
Showing 5 changed files with 1,149 additions and 115 deletions.
15 changes: 14 additions & 1 deletion src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,26 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
.set_body_method<KVState>(&KVStateObj::ForkSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
// Registers the global function "vm.builtin.kv_state_begin_forward".
// The body unpacks the arguments by hand so that both the 3-argument form
// (kv_state, seq_ids, append_lengths) and the 4-argument form (plus the
// token tree parent array) are accepted.
// NOTE(review): in this diff view the `.set_body_method` line directly below
// appears to be the single removed line of the change ("14 additions & 1
// deletion"); the `.set_body([...])` lambda that follows is its replacement.
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
.set_body_method<KVState>(&KVStateObj::BeginForward);
.set_body([](TVMArgs args, TVMRetValue* rv) {
// Reject any call that passes neither 3 nor 4 arguments.
CHECK(args.size() == 3 || args.size() == 4)
<< "KVState BeginForward only accepts 3 or 4 arguments";
// Positional arguments: the KV state handle, the sequence ids, and the
// per-sequence append lengths.
KVState kv_state = args[0];
IntTuple seq_ids = args[1];
IntTuple append_lengths = args[2];
// Starts out null; it is only overwritten when a 4th argument is supplied.
Optional<IntTuple> token_tree_parent_ptr{nullptr};
if (args.size() == 4) {
// Converted as Optional<IntTuple>, so the 4th argument itself may be null.
token_tree_parent_ptr = args[3].operator Optional<IntTuple>();
}
// Forward to the virtual BeginForward; `rv` is left untouched.
kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr);
});
// Registers KVStateObj::EndForward under the global name
// "vm.builtin.kv_state_end_forward", bound directly as a member function.
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);

// Attention KV Cache methods
// Registers AttentionKVCacheObj::EnableSlidingWindowForSeq as a global function.
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
// Registers AttentionKVCacheObj::CommitAcceptedTokenTreeNodes as a global
// function, bound the same way as the neighbouring registrations.
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes);
// Registers AttentionKVCacheObj::Empty as a global function.
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
Expand Down
15 changes: 14 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ class KVStateObj : public Object {
* in the model forward function.
* \param seq_ids The ids of the sequence to run in the incoming model forward.
* \param append_lengths The sequence lengths to run forward for each sequence.
* \param token_tree_parent_ptr The parent idx array of the token trees. Its length
* is the sum of "append_lengths". Nullptr means the token tree of each sequence
* is a chain.
*/
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) = 0;
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& token_tree_parent_ptr = NullOpt) = 0;

/*!
* \brief Mark the start of the forward function.
Expand Down Expand Up @@ -142,6 +146,15 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) = 0;

/*!
* \brief Commit the accepted token tree nodes to the KV cache.
* The commit updates the KV cache by compacting the KV data and discarding
* the KV data of rejected tokens.
* This is a mandatory step when BeginForward is given a token tree.
* \param leaf_indices The leaf token tree node index of each sequence.
*/
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;

/************** Attention **************/

/*!
Expand Down
Loading

0 comments on commit 31f4721

Please sign in to comment.