diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index a4a571b180c6..e2f8fd1999c4 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -496,8 +496,7 @@ def test_allocate_with_lookahead():
 
     # Test case 1: Requires additional lookahead tokens
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
-                                      max_model_len=100,
-                                      num_preallocate_tokens=0)
+                                      max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
         num_tokens=3,
@@ -507,25 +506,19 @@ def test_allocate_with_lookahead():
 
     # Test case 2: With precomputed blocks
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
-                                      max_model_len=100,
-                                      num_preallocate_tokens=4)
-    # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
+                                      max_model_len=100)
     # required_blocks = ceil((3 + 2) /4) = 2
-    # total_blocks = 1 + 2 = 3
     blocks = kv_cache_manager.allocate_slots(
         request,
         num_tokens=3,
         num_lookahead_tokens=2,
     )
-    assert len(blocks) == 3
+    assert len(blocks) == 2
 
     # Test case 3: With precomputed blocks
-    # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
     # required_blocks = ceil((3 + 4) / 4) = 2
-    # total_blocks = 0 + 2 = 2
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
-                                      max_model_len=100,
-                                      num_preallocate_tokens=4)
+                                      max_model_len=100)
     blocks = kv_cache_manager.allocate_slots(
         request,
         num_tokens=3,
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 669c04236925..1b238d47c03a 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -8,7 +8,7 @@
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.utils import cdiv, sha256
+from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
 from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
@@ -61,7 +61,6 @@ def test_prefill(hash_algo):
         max_model_len=8192,
         enable_caching=True,
         caching_hash_algo=hash_algo,
-        num_preallocate_tokens=16,
     )
 
     # choose the hash function according to the parameter
@@ -80,7 +79,7 @@ def test_prefill(hash_algo):
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
 
     # Check full block metadata
     parent_block_hash = None
@@ -92,8 +91,8 @@ def test_prefill(hash_algo):
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
 
-    # Check partial/preallocated block metadata
-    for block_id in (4, 5):
+    # Check partial block metadata
+    for block_id in (4, ):
         assert manager.block_pool.blocks[block_id].block_hash is None
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
 
@@ -107,12 +106,12 @@ def test_prefill(hash_algo):
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [6, 7]
+    assert [b.block_id for b in blocks] == [5]
     for block in computed_blocks:
         assert block.ref_cnt == 2
 
-    # At this point, we should have 3 free blocks left.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 3
+    # At this point, we should have 5 free blocks left.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
 
     manager.free(req0)
     manager.free(req1)
@@ -120,14 +119,14 @@ def test_prefill(hash_algo):
     # All blocks should be available.
     assert manager.block_pool.free_block_queue.num_free_blocks == 10
     # The order should be
-    # [unallocated (8, 9, 10)]
-    # [unique_req0 (5, 4)]
-    # [unique_req1 (7, 6)]
+    # [unallocated (6, 7, 8, 9, 10)]
+    # [unique_req0 (4)]
+    # [unique_req1 (5)]
     # [common (3, 2, 1)]
     assert [
         b.block_id
         for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
+    ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
 
     # Cache hit in the common prefix when the original block is already free.
     # Incomplete 1 block (6 tokens)
@@ -139,29 +138,29 @@ def test_prefill(hash_algo):
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [8, 9]
+    assert [b.block_id for b in blocks] == [6]
 
-    # Although we only have 5 free blocks, we have 8 blocks in
+    # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 5
+    assert manager.block_pool.free_block_queue.num_free_blocks == 6
     assert all([
         b.ref_cnt == 0
         for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ])
     assert len([
         b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ]) == 5
+    ]) == 6
 
     manager.free(req2)
 
     # Cache miss and eviction.
-    req3 = make_request("3", [99] * (16 * 9))
+    req3 = make_request("3", [99] * (16 * 10))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
     assert not computed_blocks
     assert num_computed_tokens == 0
-    blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
+    blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
     # This block ID order also checks the eviction order.
-    assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
+    assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -178,7 +177,6 @@ def test_prefill_plp():
         make_kv_cache_config(16, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=16,
     )
     # the default hash function is hash
     hash_fn = hash
@@ -197,7 +195,7 @@ def test_prefill_plp():
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
     req0_block_hashes = [b.block_hash for b in blocks]
 
     # Check full block metadata
@@ -210,8 +208,8 @@ def test_prefill_plp():
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
 
-    # Check partial/preallocated block metadata
-    for block_id in (4, 5):
+    # Check partial block metadata
+    for block_id in (4, ):
         assert manager.block_pool.blocks[block_id].block_hash is None
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
 
@@ -226,12 +224,12 @@ def test_prefill_plp():
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [6, 7]
+    assert [b.block_id for b in blocks] == [5]
     for block in computed_blocks:
         assert block.ref_cnt == 2
 
-    # At this point, we should have 3 free blocks left.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 3
+    # At this point, we should have 5 free blocks left.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
 
     manager.free(req0)
     manager.free(req1)
@@ -239,14 +237,14 @@ def test_prefill_plp():
     # All blocks should be available.
     assert manager.block_pool.free_block_queue.num_free_blocks == 10
     # The order should be
-    # [unallocated (8, 9, 10)]
-    # [unique_req0 (5, 4)]
-    # [unique_req1 (7, 6)]
+    # [unallocated (6, 7, 8, 9, 10)]
+    # [unique_req0 (4)]
+    # [unique_req1 (5)]
     # [common (3, 2, 1)]
     assert [
         b.block_id
         for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
+    ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
 
     # Request #2 is a prompt-logprobs request:
     # NO cache hit in the common prefix; duplicates request #0 cached blocks
@@ -262,7 +260,7 @@ def test_prefill_plp():
     block_ids = [b.block_id for b in blocks]
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks] == req0_block_hashes
-    assert block_ids != [1, 2, 3, 4, 5]
+    assert block_ids != [1, 2, 3, 4]
     # Request #2 block hashes are valid since request #0 hashes are.
 
     # Check block reference counts.
@@ -277,7 +275,6 @@ def test_decode():
         make_kv_cache_config(16, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=16,
     )
 
     # Complete 3 blocks (48 tokens)
@@ -291,7 +288,7 @@ def test_decode():
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -299,28 +296,18 @@ def test_decode():
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4)
     assert new_blocks is not None and len(new_blocks) == 0
-    assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
+    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
 
-    # Append slots without allocating a new block, but start using the
-    # preallocated block.
+    # Append slots and allocate a new block.
     req0.num_computed_tokens = 59
-    # 6 tokens to fill the previous block, and 10 tokens to fill
-    # the preallocated block.
-    for _ in range(5 + 10):
+    # 5 tokens to fill the previous block, and 14 tokens to go
+    # into a new block.
+    for _ in range(5 + 14):
         req0.append_output_token_ids(7)
-    new_blocks = manager.allocate_slots(req0, 15)
-    assert new_blocks is not None and len(new_blocks) == 0
+    new_blocks = manager.allocate_slots(req0, 19)
+    assert new_blocks is not None and len(new_blocks) == 1
     assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
-
-    # Append slots with allocating a new block.
-    req0.num_computed_tokens = 74
-    # 6 tokens to fill the previous block, and 10 tokens to fill
-    # the preallocated block.
-    for _ in range(6 + 11):
-        req0.append_output_token_ids(12)
-    new_blocks = manager.allocate_slots(req0, 17)
-    # Plus one preallocated block.
-    assert new_blocks is not None and len(new_blocks) == 2
+    assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
 
 
 def test_evict():
@@ -328,7 +315,6 @@
         make_kv_cache_config(16, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=16,
     )
 
     last_token_id = 5 * 16 + 7
@@ -337,7 +323,7 @@ def test_evict():
     assert not computed_blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
-    assert len(blocks) == 7  # 5 full + 1 partial + 1 preallocated
+    assert len(blocks) == 6  # 5 full + 1 partial
 
     # 3 blocks.
     req1 = make_request("1", list(range(last_token_id,
@@ -349,7 +335,8 @@ def test_evict():
     assert len(blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
-    assert manager.block_pool.free_block_queue.num_free_blocks == 0
+    # 10 - (6 + 3) == 1
+    assert manager.block_pool.free_block_queue.num_free_blocks == 1
 
     manager.free(req0)
     manager.free(req1)
@@ -357,7 +344,7 @@ def test_evict():
     assert [
         b.block_id
         for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
+    ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
 
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
@@ -365,8 +352,8 @@ def test_evict():
     assert [b.block_id for b in computed_blocks] == [1, 2]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3, computed_blocks)
-    assert [b.block_id for b in blocks] == [7, 6]
-    assert manager.block_pool.free_block_queue.num_free_blocks == 6
+    assert [b.block_id for b in blocks] == [10]
+    assert manager.block_pool.free_block_queue.num_free_blocks == 7
 
 
 def test_hash_block_correct_reuse():
@@ -379,7 +366,6 @@ def test_hash_block_correct_reuse():
         make_kv_cache_config(16, 2),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
 
     # Allocate 1 block and cache it.
@@ -416,7 +402,6 @@ def test_computed_blocks_not_evicted():
         make_kv_cache_config(block_size, 3),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
 
     # Allocate a block and cache it.
@@ -465,7 +450,6 @@ def test_basic_prefix_caching_disabled():
         make_kv_cache_config(block_size, 5),
         max_model_len=8192,
         enable_caching=False,
-        num_preallocate_tokens=0,
     )
 
     req1 = make_request("1", list(range(10)))  # 2 blocks and some more
@@ -496,40 +480,6 @@ def test_basic_prefix_caching_disabled():
     assert not blocks
 
 
-@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
-@pytest.mark.parametrize("block_size", [4])
-def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
-    """
-    This tests that the preallocated blocks are correctly added.
-    """
-    manager = KVCacheManager(
-        make_kv_cache_config(block_size, 11),
-        max_model_len=8192,
-        enable_caching=True,
-        num_preallocate_tokens=num_preallocate_tokens,
-    )
-    num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
-
-    req = make_request("0", list(range(block_size * 30)))
-    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
-    assert num_computed_tokens == 0
-    # Just ask for 1 block.
-    blocks = manager.allocate_slots(req, block_size, computed_blocks)
-    req.num_computed_tokens = block_size
-    assert len(blocks) == 1 + num_preallocated_blocks
-
-    # Assume all computed, only when num_preallocate_tokens > 0, we need to
-    # consume the previously preallocated blocks.
-    if num_preallocated_blocks > 0:
-        manager.allocate_slots(req, block_size * (len(blocks) - 1))
-        req.num_computed_tokens = block_size * len(blocks)
-
-    # Append 1 block.
-    blocks = manager.allocate_slots(req, block_size)
-    assert len(blocks) == 1 + num_preallocated_blocks
-
-
 @pytest.mark.parametrize("hash_fn", [sha256, hash])
 def test_cache_blocks(hash_fn):
     """
@@ -588,7 +538,6 @@ def test_mm_prefix_caching():
         make_kv_cache_config(16, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=16,
    )
 
     # Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -626,7 +575,7 @@ def test_mm_prefix_caching():
     assert block_hashes[2].extra_keys == ("bbb", )
 
     blocks = manager.allocate_slots(req0, 59, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
@@ -667,7 +616,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
     # Complete 3 blocks (48 tokens)
     # | Common-0 | Common-1 | Common-2 | ... |
@@ -721,7 +669,6 @@ def test_reset_prefix_cache():
         make_kv_cache_config(16, 11),
         max_model_len=8192,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
 
     full_block_token_ids = [i for i in range(3) for _ in range(16)]
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index f173344344f9..560a60a81446 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -804,20 +804,17 @@ def _assert_right_kv_cache_manager(
     """Check whether KVCacheManager is correct after allocate."""
     # Make sure the request stats are right.
-    EXPECTED_ACTUAL_BLOCKS = num_tokens // block_size
-    EXPECTED_TOTAL_BLOCKS = (EXPECTED_ACTUAL_BLOCKS +
-                             scheduler.kv_cache_manager.num_preallocate_blocks)
+    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
     for req_id in req_ids:
         blocks = scheduler.kv_cache_manager.req_to_blocks[req_id]
         hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
         assert (scheduler.kv_cache_manager.num_cached_block[req_id] ==
-                EXPECTED_ACTUAL_BLOCKS)
+                EXPECTED_TOTAL_BLOCKS)
         assert len(blocks) == EXPECTED_TOTAL_BLOCKS
-        assert len(hashes) == EXPECTED_ACTUAL_BLOCKS
+        assert len(hashes) == EXPECTED_TOTAL_BLOCKS
 
     # Make sure we actually touched all the blocks.
-    BLOCKS_PER_REQ = (num_tokens / block_size +
-                      scheduler.kv_cache_manager.num_preallocate_blocks)
+    BLOCKS_PER_REQ = num_tokens / block_size
     assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
             num_total_blocks - num_requests * BLOCKS_PER_REQ)
@@ -1052,7 +1049,6 @@ def test_kv_connector_handles_preemption():
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
     )
-    scheduler.kv_cache_manager.num_preallocate_blocks = 0
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index c3c83baf5129..354300d3c2fe 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -25,7 +25,6 @@ def __init__(
         max_model_len: int,
         enable_caching: bool = True,
         caching_hash_algo: str = "builtin",
-        num_preallocate_tokens: int = 64,
         log_stats: bool = False,
     ) -> None:
         assert len(kv_cache_config.kv_cache_groups) == 1, (
@@ -42,22 +41,8 @@ def __init__(
         self.log_stats = log_stats
         # FIXME: make prefix cache stats conditional on log_stats
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
 
-        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
-        # blocks for each request. For example, when a request reaches the end
-        # of its block table, we preallocate N blocks in advance. This way, we
-        # reduce the overhead of updating free_block_ids and ref_cnts for each
-        # request every step (at the cost of some memory waste).
-        # NOTE(woosuk): This is different from the "lookahead" slots since this
-        # does not guarantee that the request always has N empty blocks. After
-        # the request gets N empty blocks, it starts to use the blocks without
-        # further allocation. When it uses up all the N empty blocks, it gets
-        # N new empty blocks.
-        self.num_preallocate_tokens = num_preallocate_tokens
-        self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
-                                           self.block_size)
         self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
-
         self.specialized_manager = get_specialized_manager(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
@@ -256,13 +241,9 @@ def allocate_slots(
             # No new block is needed.
             new_blocks = []
         else:
-            # Get new blocks from the free block pool considering
-            # preallocated blocks.
-            num_preallocate_blocks = max(
-                0, self.num_preallocate_blocks -
-                num_lookahead_tokens // self.block_size)
+            # Get new blocks from the free block pool.
             num_new_blocks = min(
-                num_new_blocks + num_preallocate_blocks,
+                num_new_blocks,
                 self.block_pool.get_num_free_blocks(),
                 # Should not exceed the maximum number of blocks per request.
                 # This is especially because the block table has the shape
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 16efc42f212e..5adcdde5bcd7 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -358,8 +358,11 @@ def schedule(self) -> SchedulerOutput:
                 new_encoder_budget = encoder_budget
 
             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens + num_external_tokens,
-                computed_blocks)
+                request,
+                num_new_tokens + num_external_tokens,
+                computed_blocks,
+                num_lookahead_tokens=self.num_lookahead_tokens,
+            )
             if new_blocks is None:
                 # The request cannot be scheduled.
                 break
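
Note on the allocation arithmetic after this change (not part of the patch itself): with preallocation removed, allocate_slots hands out exactly the blocks a request needs for its computed, new, and lookahead tokens, capped by the free pool and the per-request block-table width. The sketch below is a minimal, standalone illustration of that arithmetic using the numbers asserted in test_allocate_with_lookahead; the function and parameter names are illustrative only and are not the actual KVCacheManager API.

from math import ceil


def blocks_needed(num_computed_tokens: int,
                  num_new_tokens: int,
                  num_lookahead_tokens: int,
                  num_allocated_blocks: int,
                  block_size: int,
                  num_free_blocks: int,
                  max_blocks_per_req: int) -> int:
    """How many new blocks a request gets, with no preallocation."""
    total_tokens = (num_computed_tokens + num_new_tokens +
                    num_lookahead_tokens)
    required_blocks = ceil(total_tokens / block_size)
    num_new_blocks = required_blocks - num_allocated_blocks
    if num_new_blocks <= 0:
        # The blocks already held by the request cover every token.
        return 0
    # Cap by the free pool and by the block table width, mirroring the
    # min(...) that remains in allocate_slots after this patch.
    return min(num_new_blocks, num_free_blocks,
               max_blocks_per_req - num_allocated_blocks)


# Test case 2 above: block_size=4, 3 new tokens + 2 lookahead tokens
# -> ceil((3 + 2) / 4) = 2 blocks, with no extra preallocated block.
assert blocks_needed(0, 3, 2, 0, 4, 10, 16) == 2
# Test case 3 above: 3 new tokens + 4 lookahead tokens -> ceil(7 / 4) = 2.
assert blocks_needed(0, 3, 4, 0, 4, 10, 16) == 2

The scheduler side of the patch feeds num_lookahead_tokens into this computation, so slots for speculative tokens are reserved explicitly rather than relying on the removed preallocation margin.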