Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 93 additions & 33 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def test_prefill(hash_algo):
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]

# Check full block metadata
Expand All @@ -108,7 +110,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
Expand Down Expand Up @@ -140,7 +144,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6]

# Although we only have 6 free blocks, we have 8 blocks in
Expand All @@ -161,7 +167,9 @@ def test_prefill(hash_algo):
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
Expand Down Expand Up @@ -197,7 +205,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks.blocks]

Expand Down Expand Up @@ -226,7 +236,9 @@ def test_prefill_plp():
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
Expand Down Expand Up @@ -259,7 +271,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
Expand Down Expand Up @@ -290,14 +304,18 @@ def test_decode():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
Expand All @@ -308,7 +326,9 @@ def test_decode():
# the preallocated block.
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None
Expand All @@ -328,7 +348,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial

# 3 blocks.
Expand All @@ -337,7 +359,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

Expand All @@ -357,7 +381,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7

Expand All @@ -380,7 +406,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1

# Deallocate the block.
Expand All @@ -392,7 +420,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1

assert manager.block_pool.blocks[
Expand All @@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1

Expand All @@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2

Expand All @@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == block_size

blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
Expand All @@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3

# Free the blocks.
Expand All @@ -475,15 +512,19 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 4

# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert not blocks


Expand Down Expand Up @@ -581,14 +622,18 @@ def test_mm_prefix_caching():
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )

blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0

# The just completed block should have hashes with extra keys.
Expand Down Expand Up @@ -638,14 +683,18 @@ def test_cache_key_salting():
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None

blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0

# Now one more block that should not have extra keys.
Expand Down Expand Up @@ -691,15 +740,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]

# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
Expand All @@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)

# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
Expand All @@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
Expand All @@ -751,7 +805,9 @@ def test_reset_prefix_cache():
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]

# Failed to reset prefix cache because some blocks are not freed yet.
Expand Down Expand Up @@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks)
manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
manager.reset_prefix_cache()

# Ensure prefix_cache_stats remains None
Expand Down Expand Up @@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block():

# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)

# New request with same tokens + Eagle enabled
Expand Down Expand Up @@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks():

# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)

# New request with Eagle enabled
Expand Down Expand Up @@ -928,7 +987,8 @@ def test_eagle_with_sliding_window():

# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None
Expand Down
Loading