[Runtime] Fix PagedKVCache for PopN and enhance tests (#17045)
This PR fixes a bug in PagedKVCache that could be triggered when
sequences are removed in an order that is not the reverse of the
order in which they were added or forked. With this fix, PagedKVCache
supports removing sequences in any order without breaking.

This PR also adds an `Empty` method to PagedKVCache that checks
whether the KV cache is empty. Right now this method is only used
for testing, where we check that everything in the KV cache has been
freed after removing all sequences.
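
As an illustration of the new behavior, here is a minimal sketch of the scenario this fix enables, written with the packed-function helpers that the test file below fetches in `set_global_func` (`fadd_sequence`, `ffork_sequence`, `fremove_sequence`, `fis_empty`). The sequence ids, the fork argument order, and the pre-built `kv_cache` object are assumptions made for the example, not part of this commit:

    # Sequence 1 is forked from sequence 0, and the parent is removed first:
    # a removal order that is not the reverse of the add/fork order and that
    # previously could break the reference counting inside the cache.
    fadd_sequence(kv_cache, 0)
    # ... append some tokens to sequence 0 via the usual prefill/decode path ...
    ffork_sequence(kv_cache, 0, 1, -1)  # assumed order: (cache, parent, child, fork position)
    fremove_sequence(kv_cache, 0)       # parent removed before its fork
    fremove_sequence(kv_cache, 1)
    # The new `empty` query lets tests assert that every block and page has
    # been returned to the free lists.
    assert fis_empty(kv_cache)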
MasterJH5574 authored May 30, 2024
1 parent 820f1b6 commit 1eac178
Showing 4 changed files with 62 additions and 21 deletions.
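
The core of the fix, visible in the paged_kv_cache.cc hunks below, is a change in the block reference-counting convention: a block's `external_ref_cnt` now also counts the sequence that owns it (the `Sequence` constructor increments it), blocks are freed while the count equals 1 instead of 0, and the old check that refused to remove a sequence whose blocks were still referenced by a fork is gone; a removed sequence now just drops one reference on the first block it shares with someone else. The toy Python model below only illustrates that idea; the names and structure are simplified assumptions, not the actual C++ implementation:

    # Toy model of the new reference-counting scheme (illustration only).
    class Block:
        def __init__(self, parent=None):
            self.parent = parent   # parent block id or None
            self.ref_cnt = 0       # references from owning sequences and child blocks

    blocks = {}       # block id -> Block
    sequences = {}    # sequence id -> id of its last block
    next_block_id = 0

    def new_block(parent=None):
        global next_block_id
        idx = next_block_id
        next_block_id += 1
        blocks[idx] = Block(parent)
        if parent is not None:
            blocks[parent].ref_cnt += 1   # a child block references its parent
        return idx

    def add_sequence(seq_id):
        idx = new_block()
        blocks[idx].ref_cnt += 1          # the owning sequence holds one reference
        sequences[seq_id] = idx

    def fork_sequence(parent_seq_id, child_seq_id):
        child_last = new_block(parent=sequences[parent_seq_id])
        blocks[child_last].ref_cnt += 1   # owned by the child sequence
        sequences[child_seq_id] = child_last
        # (the real code may also append a fresh tail block to the parent here)

    def remove_sequence(seq_id):
        idx = sequences.pop(seq_id)
        # Free blocks that only this owner references; stop at shared blocks.
        while idx is not None and blocks[idx].ref_cnt == 1:
            parent = blocks[idx].parent
            del blocks[idx]               # return the block (and its pages) to the free pool
            idx = parent
        if idx is not None:
            blocks[idx].ref_cnt -= 1      # drop the single reference held by the removed chain

    # Removal order no longer matters: a parent can go away before its fork,
    # because the shared blocks stay alive until the last reference is dropped.
    add_sequence(0)
    fork_sequence(0, 1)
    remove_sequence(0)            # parent first: its block survives with ref_cnt == 1
    remove_sequence(1)            # child last: everything is freed
    assert not blocks and not sequences   # the analogue of the new Empty() check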
2 changes: 2 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
@@ -47,6 +47,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
// Attention KV Cache methods
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length")
2 changes: 2 additions & 0 deletions src/runtime/relax_vm/kv_state.h
@@ -117,6 +117,8 @@ class AttentionKVCacheObj : public KVStateObj {
public:
/************** Raw Info Query **************/

+ /*! \brief Check if the KV cache is empty. */
+ virtual bool Empty() const = 0;
/*!
* \brief Get the number of available pages in the KV cache.
* When the underlying KV cache implementation is not
49 changes: 32 additions & 17 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -147,13 +147,14 @@ struct Sequence {
*/
int last_block_attn_sink_size = 0;

- explicit Sequence(const std::vector<Block>& global_block_pool, int32_t last_block_idx) {
+ explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
+ ++global_block_pool->at(last_block_idx).external_ref_cnt;
this->last_block_idx = last_block_idx;
int32_t block_ptr = last_block_idx;
// Go through each block in the sequence, sum up the length.
int depth = 0;
while (true) {
- const Block& block = global_block_pool[block_ptr];
+ const Block& block = global_block_pool->at(block_ptr);
this->seq_length += block.seq_length;
++depth;
if (block.parent_idx == -1) {
@@ -965,17 +966,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
CHECK(seq_map_.find(seq_id) == seq_map_.end())
<< "The sequence \"" << seq_id << "\" is already in the KV cache.";
int32_t block_idx = GetFreeBlock();
- seq_map_.insert({seq_id, Sequence(global_block_pool_, block_idx)});
+ seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)});
dirty_aux_data_device_ = true;
}

void RemoveSequence(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
int32_t block_idx = it->second.last_block_idx;
- CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
- << "The sequence is currently referenced by other sequence and thus cannot be removed.";
- while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+ // The block should have at least one reference, which comes from the sequence.
+ ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
// - Free pages in the last block.
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
Expand All @@ -985,7 +986,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
// - Decrease the external reference of the parent block.
if (block_idx != -1) {
- ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
+ ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1);
--global_block_pool_[block_idx].external_ref_cnt;
}
seq_map_.erase(it);
@@ -1018,11 +1019,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
- if (global_block_pool_[parent_block_idx].seq_length) {
- // If parent is not empty, append a new block
+ if (parent_block_idx == parent_it->second.last_block_idx &&
+ global_block_pool_[parent_block_idx].seq_length) {
+ // To enable the parent sequence to continue decode after the fork,
+ // we add a new empty block at the end of the parent sequence.
+ // So the new decoded KV data will go into the new block.
int32_t new_parent_block_idx = GetFreeBlock();
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
+ global_block_pool_[new_parent_block_idx].external_ref_cnt = 1;
parent_it->second.last_block_idx = new_parent_block_idx;
}
} else {
@@ -1055,7 +1060,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
global_block_pool_[forked_block_idx].parent_idx;
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
- global_block_pool_[parent_block_idx].external_ref_cnt = 1;
+ global_block_pool_[parent_block_idx].external_ref_cnt = 2;

// Move common leading pages to new parent block
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
Expand Down Expand Up @@ -1085,7 +1090,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}
// Create the child sequence with the child block.
- seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
+ seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
dirty_aux_data_device_ = true;
}

@@ -1119,7 +1124,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "A sequence cannot be enabled twice for sliding window.";

// Compute the total length of the prefix blocks of this sequence.
- Block& last_block = global_block_pool_[it->second.last_block_idx];
+ const Block& last_block = global_block_pool_[it->second.last_block_idx];
int32_t prefix_length = it->second.seq_length - last_block.seq_length;
ICHECK_GE(prefix_length, 0);
// Since the prefix blocks cannot sliding, they are natural
@@ -1139,7 +1144,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "The sequence only has length " << it->second.seq_length
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
int32_t block_idx = it->second.last_block_idx;
- while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
+ // The block should have at least one reference, which comes from the sequence.
+ ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
if (n > global_block_pool_[block_idx].seq_length) {
n -= global_block_pool_[block_idx].seq_length;
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
@@ -1168,14 +1175,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}

if (n) {
- int32_t temp_seq_id = -1 - seq_id;
+ // We use a temporary sequence id for fork.
+ // This temporary seq id will immediately end its effect outside this function.
+ int64_t temp_seq_id = -1 - seq_id;
CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
RemoveSequence(seq_id);
CHECK(seq_map_.find(seq_id) == seq_map_.end());
auto it = seq_map_.find(temp_seq_id);
- seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
+ seq_map_.insert({seq_id, it->second});
seq_map_.erase(temp_seq_id);
}
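
The `if (n)` branch above is the heart of the PopN fix: when the pop reaches past the blocks that the sequence owns exclusively, the cache forks the sequence at the target length under a temporary id, removes the original sequence, and re-registers the forked copy under the original id. This is what the new `(4, 37)` case in the popn test below exercises: sequence 4 is forked from sequence 3 and then receives 35 appended tokens, so popping 37 tokens has to cross into blocks shared with the parent. A minimal sketch at the Python level (helper names come from the test's `set_global_func`; the fork and append steps are abbreviated):

    # Sequence 4 owns only the 35 tokens appended after its fork from sequence 3.
    # ... fork sequence 4 from sequence 3 and append 35 tokens (see the test) ...
    # Popping 37 tokens reaches back into the shared region, taking the
    # temporary-fork path implemented above.
    fpopn(kv_cache, 4, 37)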

@@ -1184,6 +1193,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

/************** Raw Info Query **************/

+ bool Empty() const final {
+ return seq_map_.empty() &&  //
+ free_block_idx_.size() == global_block_pool_.size() &&  //
+ free_page_ids_.size() == static_cast<size_t>(num_total_pages_);
+ }

int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); }

int32_t GetTotalSequenceLength() const final {
@@ -1565,8 +1580,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
int32_t block_idx = seq->last_block_idx;
Block& block = global_block_pool_[block_idx];
CHECK_GT(append_length, 0) << "Append with length 0 is not allowed.";
- CHECK_EQ(block.external_ref_cnt, 0)
- << "The block is " << block.external_ref_cnt
+ CHECK_EQ(block.external_ref_cnt, 1)
+ << "The block is " << block.external_ref_cnt - 1
<< "-time referenced by other blocks, thus cannot accept new KV values.";

// ==================== Reserve ====================
30 changes: 26 additions & 4 deletions (Python test file for the paged attention KV cache; file path not shown in this capture)
@@ -54,6 +54,7 @@
fbegin_forward = None
fend_forward = None
fattention_with_fuse_qkv = None
+ fis_empty = None
fdebug_get_kv = None

ftranspose_append = None
@@ -71,7 +72,7 @@

def set_global_func(head_dim, dtype):
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
- global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv
+ global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
global fattn_prefill_sliding_window, fattn_decode_sliding_window
global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
@@ -89,6 +90,7 @@ def set_global_func(head_dim, dtype):
fattention_with_fuse_qkv = tvm.get_global_func(
"vm.builtin.attention_kv_cache_attention_with_fused_qkv"
)
+ fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")

target = tvm.target.Target("cuda")
@@ -489,11 +491,19 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
for batch in operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)

- for i in range(19, -1, -1):
+ num_sequence = 20
+ for i in range(num_sequence):
fremove_sequence(kv_cache, i)
cached_k.pop(i)
cached_v.pop(i)
- verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
+ verify_cached_kv(
+ kv_cache,
+ seq_ids=list(range(i + 1, num_sequence)),
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )

+ assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"


@tvm.testing.requires_gpu
@@ -510,14 +520,26 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)

- popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
+ popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
for seq_id, pop_length in popn_operations:
fpopn(kv_cache, seq_id, pop_length)
if pop_length != 0:
cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)

+ num_sequence = 5
+ for seq_id in range(num_sequence):
+ fremove_sequence(kv_cache, seq_id)
+ verify_cached_kv(
+ kv_cache,
+ seq_ids=list(range(seq_id + 1, num_sequence)),
+ expected_k=cached_k,
+ expected_v=cached_v,
+ )

+ assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
