From 35b3d95e4c11a2683d76eb6539513f009a97c1ee Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Mon, 30 Jun 2025 12:23:08 -0700 Subject: [PATCH 1/3] test: KVBM vLLM python tests (#1463) --- .../tests/test_kvbm_vllm_integration.py | 817 ++++++++++++++++++ 1 file changed, 817 insertions(+) create mode 100644 lib/bindings/python/tests/test_kvbm_vllm_integration.py diff --git a/lib/bindings/python/tests/test_kvbm_vllm_integration.py b/lib/bindings/python/tests/test_kvbm_vllm_integration.py new file mode 100644 index 0000000000..04ee8b9d7e --- /dev/null +++ b/lib/bindings/python/tests/test_kvbm_vllm_integration.py @@ -0,0 +1,817 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import time +from typing import Optional + +import pytest +import torch + +try: + from vllm.multimodal.inputs import MultiModalKwargs + from vllm.sampling_params import SamplingParams + from vllm.v1.core.kv_cache_manager import Request + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + ) + + VLLM_NOT_AVAILABLE = False +except ImportError: + VLLM_NOT_AVAILABLE = True + +from dynamo.llm import BlockManager +from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager + + +def new_kv_cache_manager( + worker_id: int = 0, + num_layer: int = 1, + outer_dim: int = 1, + page_size: int = 16, + inner_dim: int = 1, + device_id: int = 0, + num_blocks: int = 11, +): + """ + Creates a new KVBM cache manager. + + Returns: + KvbmCacheManager: The KVBM cache manager. + """ + return KvbmCacheManager( + BlockManager( + worker_id, + num_layer, + outer_dim, + page_size, + inner_dim, + "FP32", # dtype + num_blocks, # host_num_blocks + num_blocks, # device_num_blocks + device_id, + ) + ) + + +def make_request( + request_id, + prompt_token_ids, + mm_positions=None, + mm_hashes=None, + prompt_logprobs: Optional[int] = None, + cache_salt: Optional[str] = None, +): + if mm_positions is None: + multi_modal_inputs = None + else: + multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) + + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_inputs=multi_modal_inputs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), + eos_token_id=100, + arrival_time=0, + lora_request=None, + cache_salt=cache_salt, + ) + + +def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], + FullAttentionSpec(block_size, 1, 1, torch.float32, False), + ) + ], + ) + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_prefill(): + """ + Tests the KvbmCacheManager's prefill functionality. 
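+
+    Walkthrough (16-token pages, 11 device blocks, per new_kv_cache_manager defaults):
+      1. A 55-token prompt allocates 3 full blocks plus 1 partial block, all misses.
+      2. A second request sharing the 48-token prefix hits the 3 full blocks.
+      3. Once both are freed, an 11-block request forces eviction of the cached blocks.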
+ """ + manager = new_kv_cache_manager() + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(16)] + + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) + + # Step 1: Initial allocation - no computed blocks yet + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + + # Step 2: Allocate slots for the request + blocks_req0 = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks) * 16, computed_blocks + ) + + for block in blocks_req0.blocks: + assert block._block_hash is None + + # Verify allocation was successful + block_ids = manager.get_block_ids(req0.request_id) + assert len(block_ids) == 1 # One sequence in the request + assert len(block_ids[0]) == 4 # 4 blocks allocated (3 complete + 1 partial) + + # Step 3: Simulate model execution by updating the request's computed tokens + req0.append_output_token_ids(100) + req0.num_computed_tokens = 55 + + _ = manager.allocate_slots(req0, num_new_tokens=1) + + # Step 5: Create a new request with the same prefix plus one token + unique_token_ids = [3] * 4 + req1 = make_request("1", common_token_ids + unique_token_ids) + + # Step 8: Check for computed blocks - should find the common prefix + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks) == 3 + assert num_computed_tokens == len(computed_blocks.blocks) * 16 + + for block in computed_blocks.blocks: + assert block._block_hash is not None + print(block) + + # Clean up + manager.free_block_hashes(req0) + manager.free_block_hashes(req1) + + # TODO(oandreeva): + # Currently need a delay here, since there's an intermittency. + # Not all blocks become available for KVBM to assign to the new request. + # Heppens with `free` and `free_block_hashes` calls. + time.sleep(0.5) + + # Cache miss and eviction. + req3 = make_request("3", [24] * (16 * 11)) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks_req3 = manager.allocate_slots( + req3, 16 * 11, len(computed_blocks.blocks) * 16, computed_blocks + ) + + assert len(blocks_req3.blocks) == 11 + for block, expected_block_id in zip( + blocks_req3.blocks, [4, 5, 6, 7, 8, 9, 10, 3, 2, 1, 0] + ): + assert block._block_hash is None + assert block.block_id == expected_block_id + + time.sleep(0.5) + + +@pytest.mark.skip(reason="KVBM needs to support reset_prefix_cache") +def test_prefill_plp(): + """Test prefill with APC and some prompt logprobs (plp) requests. + + 1. Schedule plp request and validate APC block allocation + 2. Schedule non-plp request and validate blocks + 3. 
Schedule plp request; no hit should occur; validate blocks + """ + manager = new_kv_cache_manager() + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(16)] + + # Request #0 is a prompt logprobs request + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids, prompt_logprobs=5) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + # assert len(manager.req_to_block_hashes[req0.request_id]) == 0 + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks) * 16, computed_blocks + ) + + # assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [[0, 1, 2, 3]] + req0_block_hashes = [b.block_hash for b in blocks.blocks] + + # Step 3: Simulate model execution by updating the request's computed tokens + req0.append_output_token_ids(100) + req0.num_computed_tokens = 55 + + _ = manager.allocate_slots(req0, num_new_tokens=1) + + # Check full block metadata + """ + parent_block_hash = None + for block_id in (1, 2, 3): + block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, + block_tokens) + assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + parent_block_hash = block_hash.hash_value + + # Check partial block metadata + for block_id in (4, ): + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + """ + + # Request #1 is a non-prompt-logprobs request: + # Cache hit in the common prefix when the original block is still in use. + # Incomplete 1 block (5 tokens) + unique_token_ids = [3] * 5 + req1 = make_request("1", common_token_ids + unique_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + # assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + # assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [[0, 1, 2]] + assert num_computed_tokens == 3 * 16 + num_new_tokens = 53 - 3 * 16 + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks + ) + # assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [[4]] + # for block in computed_blocks.blocks: + # assert block.ref_cnt == 2 + + # At this point, we should have 5 free blocks left. + # assert manager.block_pool.free_block_queue.num_free_blocks == 5 + + manager.free(req0) + manager.free(req1) + + """ + # All blocks should be available. 
+ assert manager.block_pool.free_block_queue.num_free_blocks == 10 + # The order should be + # [unallocated (6, 7, 8, 9, 10)] + # [unique_req0 (4)] + # [unique_req1 (5)] + # [common (3, 2, 1)] + assert [ + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] + """ + + # Request #2 is a prompt-logprobs request: + # NO cache hit in the common prefix; duplicates request #0 cached blocks + unique_token_ids = [3] * 6 + req2 = make_request("2", common_token_ids + unique_token_ids, prompt_logprobs=5) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + # assert len(manager.req_to_block_hashes[req2.request_id]) == 0 + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req2, 55, len(computed_blocks.blocks) * 16, computed_blocks + ) + block_ids = blocks.get_block_ids() + # Duplicate cached blocks have different ids but same hashes vs request #0 + assert [b.block_hash for b in blocks.blocks] == req0_block_hashes + assert block_ids != [[1, 2, 3, 4]] + + # Request #2 block hashes are valid since request #0 hashes are. + # Check block reference counts. + for block_id in block_ids[0]: + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + manager.free(req2) + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_decode(): + manager = new_kv_cache_manager() + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(16)] + + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + req0 = make_request("0", common_token_ids + unique_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks) * 16, computed_blocks + ) + # assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [[0, 1, 2, 3]] + # Append slots without allocating a new block. + req0.num_computed_tokens = 55 + for _ in range(4): + req0.append_output_token_ids(8) + + new_blocks = manager.allocate_slots( + req0, 4, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert new_blocks is not None and len(new_blocks.blocks) == 0 + + # NOTE(): There's no way to access the current active non-registered block + # from the python bindings. + # assert manager.single_type_manager.req_to_blocks[ + # req0.request_id][-1].block_hash is None + + # Append slots with allocating a new block. + req0.num_computed_tokens = 59 + # 9 tokens to fill the previous block, and 10 tokens to fill + # the preallocated block. 
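+    # (The 9 + 10 split mirrors vLLM's native test; with 59 computed tokens here,
+    # block 3 holds 11 tokens, so 5 of the 19 appended tokens complete it and the
+    # remaining 14 land in the single new block asserted below.)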
+ for _ in range(9 + 10): + req0.append_output_token_ids(7) + + print(len(computed_blocks.blocks)) + new_blocks = manager.allocate_slots( + req0, 19, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert new_blocks is not None and len(new_blocks.blocks) == 1 + assert new_blocks.blocks[-1].block_hash is None + + req0.num_computed_tokens = 78 + req0.append_output_token_ids(100) + + # The following is required for KVBM to register the block with id=3 + _ = manager.allocate_slots( + req0, 1, len(computed_blocks.blocks) * 16, computed_blocks + ) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + + # assert manager.single_type_manager.req_to_blocks[ + # req0.request_id][-2].block_hash is not None + # assert manager.single_type_manager.req_to_blocks[ + # req0.request_id][-1].block_hash is None + assert computed_blocks.blocks[-1].block_id == 3 + assert computed_blocks.blocks[-1].block_hash is not None + + # Clean up + manager.free_block_hashes(req0) + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_evict(): + manager = new_kv_cache_manager() + used_blocks = set() + + last_token_id = 5 * 16 + 7 + req0 = make_request("0", list(range(last_token_id))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, 5 * 16 + 7, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert len(blocks.blocks) == 6 # 5 full + 1 partial + used_blocks.update(blocks.get_block_ids()[0]) + + req0.append_output_token_ids(100) + req0.num_computed_tokens = 5 * 16 + 7 + manager.allocate_slots(req0, 1, len(computed_blocks.blocks) * 16, computed_blocks) + + req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16 - 1))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req1, 3 * 16, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert ( + len(blocks.blocks) == 3 + ) # 2 full blocks and 1 partial (15 tokens) 1 more will be added during allocate_slots + last_token_id += 3 * 16 - 1 + used_blocks.update(blocks.get_block_ids()[0]) + + # 10 - (6 + 3) == 1 + assert len(used_blocks) == 6 + 3 + + req1.append_output_token_ids(100) + req1.num_computed_tokens = 3 * 16 - 1 + blocks = manager.allocate_slots( + req1, 1, len(computed_blocks.blocks) * 16, computed_blocks + ) + + manager.free(req0) + manager.free(req1) + # Can't access the free blocks queue from the python bindings. + # assert manager.block_pool.free_block_queue.num_free_blocks == 10 + # assert [ + # b.block_id + # for b in manager.block_pool.free_block_queue.get_all_free_blocks() + # ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] + + # Touch the first 2 blocks. + req2 = make_request("2", list(range(2 * 16 + 3))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + # assert computed_blocks.get_block_ids() == [[1, 2]] + assert computed_blocks.get_block_ids() == [[0, 1]] + assert num_computed_tokens == 2 * 16 + blocks = manager.allocate_slots( + req2, 3, len(computed_blocks.blocks) * 16, computed_blocks + ) + + assert blocks.get_block_ids() == [[9]] + # Can't access the free blocks queue from the python bindings. 
+ # assert manager.block_pool.free_block_queue.num_free_blocks == 7 + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_hash_block_correct_reuse(): + """ + This tests when a previously cached block is reused as a new block, + its hash metadata should be correctly reset. + """ + block_size = 16 + manager = new_kv_cache_manager(num_blocks=2) + + # Allocate 1 block and cache it. + num_tokens = block_size + req = make_request("0", list(range(num_tokens))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert len(blocks.blocks) == 1 + for t in range(5): + req.append_output_token_ids(100) + req.num_computed_tokens = num_tokens + blocks = manager.allocate_slots( + req, 5, len(computed_blocks.blocks) * 16, computed_blocks + ) + + computed_blocks, _ = manager.get_computed_blocks(req) + assert computed_blocks.blocks[0].block_hash is not None + assert computed_blocks.blocks[0].block_id == 0 + + # Deallocate the block. + manager.free(req) + # Note(oandreeva): need to fix this in the kvbm core to not depend on time.sleep() + time.sleep(2) + + # Allocate new blocks, last one is partial not full, make sure hash info on the + # blocks are cleared. + # KVBM will allocate block 1 first, then block 0. Need to verify, + # that block's 0 hash is cleared + req = make_request("1", list(range(256, 256 + 2 * num_tokens - 1))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req, 2 * num_tokens - 1, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert len(blocks.blocks) == 2 + + assert blocks.blocks[1].block_id == 0 + assert blocks.blocks[1].block_hash is None + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_computed_blocks_not_evicted(): + """ + Test that the computed blocks are not evicted when getting new blocks + for a request if there are any other free blocks. + """ + block_size = 16 + manager = new_kv_cache_manager(num_blocks=3) + + # Allocate a block and cache it. + num_tokens = block_size * 1 + req0 = make_request("0", list(range(num_tokens))) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert len(blocks.blocks) == 1 + # assert blocks.blocks[0].block_id == 1 + assert blocks.blocks[0].block_id == 0 + + # Allocate another block. 
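+    # (The manager was created with num_blocks=3, so req0 and req1 below leave
+    # exactly one device block free.)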
+    req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert not computed_blocks.blocks
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(
+        req1, num_tokens, len(computed_blocks.blocks) * 16, computed_blocks
+    )
+    assert len(blocks.blocks) == 1
+    # assert blocks.blocks[0].block_id == 2
+    assert blocks.blocks[0].block_id == 1
+
+    # Need to simulate the forward pass to get blocks registered
+    req0.append_output_token_ids(100)
+    req0.num_computed_tokens = num_tokens
+    _ = manager.allocate_slots(
+        req0, 1, len(computed_blocks.blocks) * 16, computed_blocks
+    )
+
+    req1.append_output_token_ids(100)
+    req1.num_computed_tokens = num_tokens
+    _ = manager.allocate_slots(
+        req1, 1, len(computed_blocks.blocks) * 16, computed_blocks
+    )
+
+    # Free the blocks.
+    manager.free(req0)
+    manager.free(req1)
+    # Note(oandreeva): need to fix this in the kvbm core to not depend on time.sleep()
+    time.sleep(2)
+
+    # Now if we have a cache hit on the block_id 0, we should evict the block_id 1
+    # cached block rather than the first one.
+    req2 = make_request("2", list(range(num_tokens * 3)))
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+    assert len(computed_blocks.blocks) == 1
+    # assert computed_blocks.blocks[0].block_id == 1
+    assert computed_blocks.blocks[0].block_id == 0
+    assert num_computed_tokens == block_size
+
+    # Allocate should return a free block with id 2 first, and then block with id 1
+    # which was evicted.
+    blocks = manager.allocate_slots(
+        req2,
+        num_tokens * 3 - num_computed_tokens,
+        len(computed_blocks.blocks) * 16,
+        computed_blocks,
+    )
+    assert len(blocks.blocks) == 2
+    assert blocks.blocks[0].block_id == 2
+    assert blocks.blocks[1].block_id == 1
+
+
+def _test_basic_prefix_caching_disabled():
+    """
+    Currently, KVBM does not support `enable_caching`, i.e. setting it to False
+    to disable prefix caching.
+    """
+    pass
+
+
+# @pytest.mark.parametrize("hash_fn", [sha256, hash])
+def _test_cache_blocks(hash_fn):
+    """
+    Hashing is done by KVBM and tested by the core library.
+    """
+    pass
+
+
+def _test_mm_prefix_caching():
+    """
+    KVBM currently does not support multi-modal prefix caching.
+    This would test that multi-modal prefix caching is correct.
+    """
+    pass
+
+
+@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available")
+def test_cache_key_salting():
+    """
+    This tests that cache salts are applied during hashing and the cache
+    is separated as expected.
+
+    The test is mostly the same as the one for vLLM's native KV cache manager.
+    The only difference is that for KVBM we don't need a `BlockHashType` object
+    on the Python side, so we don't check the value of the salt. We validate the
+    salting behavior instead through cache misses and hits with different salts.
+    """
+    block_size = 16
+    manager = new_kv_cache_manager()
+
+    # 3 complete blocks and an incomplete block with 11 tokens.
+    common_token_ids = [i for i in range(3) for _ in range(block_size)]
+    token_ids = common_token_ids + [3] * 11
+    req0 = make_request("0", token_ids, cache_salt="salt1")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+
+    # Completed block should have hashes with extra keys.
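+    # (Comment kept from vLLM's native test; KVBM exposes no hash objects on the
+    # Python side, so only the initial cache miss is asserted here.)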
+ assert not computed_blocks.blocks + assert num_computed_tokens == 0 + """ + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("salt1", ) + assert block_hashes[1].extra_keys is None + assert block_hashes[2].extra_keys is None + """ + + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert blocks.get_block_ids() == [[0, 1, 2, 3]] # [[1, 2, 3, 4]] + req0.num_computed_tokens = 59 + + # Append slots without allocating a new block. + for _ in range(5): + req0.append_output_token_ids(8) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks) * 16, computed_blocks + ) + assert new_blocks is not None and len(new_blocks.blocks) == 0 + print(new_blocks) + """ + # Now one more block that should not have extra keys. + assert len(block_hashes) == 4 + assert block_hashes[3].extra_keys is None + """ + # Test cache hit with a new request that has the same salt. + token_ids = common_token_ids + [4] * 11 + req1 = make_request("1", token_ids, cache_salt="salt1") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + # Should match only a prefix of 3 blocks. + assert len(computed_blocks.blocks) == 3 + assert num_computed_tokens == 3 * block_size + + # Test cache miss with same content but different salt. + token_ids = common_token_ids + [4] * 11 + req2 = make_request("2", token_ids, cache_salt="salt2") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks.blocks) == 0 + assert num_computed_tokens == 0 + """ + block_hashes = manager.req_to_block_hashes[req2.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("salt2", ) + """ + + +@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +def test_prefill_not_enough_free_blocks_with_computed_blocks(): + """ + This is a unit test that tests the correctness of the allocate_slots + when there is not enough free blocks. Specifically, when a request + has computed blocks but cannot be allocated due to not enough free blocks, + the computed blocks should not be touched. + """ + block_size = 16 + manager = new_kv_cache_manager() + + # Complete 3 blocks (48 tokens) + # | Common-0 | Common-1 | Common-2 | ... | + common_token_ids = [i for i in range(3) for _ in range(16)] + req0 = make_request("0", common_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + manager.allocate_slots(req0, 48, len(computed_blocks.blocks) * 16, computed_blocks) + # block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] + block_part0 = len(manager.get_block_ids(req0.request_id)[0]) + + # Simulate model execution by updating the request's computed tokens + req0.append_output_token_ids(100) + req0.num_computed_tokens = 48 + _ = manager.allocate_slots(req0, num_new_tokens=1) + + # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... 
| + req1 = make_request("1", common_token_ids * 2) # Double the common tokens + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert ( + len(computed_blocks.blocks) == block_part0 + ) # First 3 blocks are computed from req0 + assert num_computed_tokens == 3 * 16 # 3 blocks * 16 tokens per block + manager.allocate_slots(req1, 48, num_computed_tokens, computed_blocks) + # block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] + block_part1 = len(manager.get_block_ids(req1.request_id)[0]) + + # Simulate forward pass for req1 to compute all 6 blocks + req1.append_output_token_ids(100) + req1.num_computed_tokens = 96 + _ = manager.allocate_slots(req1, num_new_tokens=1) + + # Free req1 to make its blocks available + manager.free(req1) + + # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | + # | Req1-5(F)| Req2-0 | Req2-1 | ... | + req2 = make_request("2", [7] * block_size * 2) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert not computed_blocks.blocks + assert num_computed_tokens == 0 + manager.allocate_slots( + req2, block_size * 2, len(computed_blocks.blocks) * 16, computed_blocks + ) + + # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, + # but it cannot be allocated due to insufficient free blocks (2). + # In this case, the ref_cnt of the computed blocks should not be changed. + req3 = make_request("3", common_token_ids * 2) # Use same tokens as req1 + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) + assert len(computed_blocks.blocks) == block_part1 # Should find 6 computed blocks + assert num_computed_tokens == 6 * 16 # 6 blocks * 16 tokens per block + + # Req3 cannot be allocated due to insufficient free blocks + # DYN LOG print: + # DEBUG dynamo_llm::block_manager::pool::state: not enough blocks available, requested: 3, available: 2 + assert ( + manager.allocate_slots( + req3, 48, len(computed_blocks.blocks) * 16, computed_blocks + ) + is None + ) + + # Clean up + manager.free_block_hashes(req0) + manager.free_block_hashes(req2) + manager.free_block_hashes(req3) + + +def _test_reset_prefix_cache(): + """ + `reset_prefix_cache` is currently not implemented. + It returns False every time it is called + """ + pass + + +def _test_prefix_cache_stats_disabled(): + """ + `reset_prefix_cache` is currently not implemented. + It returns False every time it is called + """ + pass + + +# @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) +def _test_kv_cache_events(blocks_to_cache: int): + """ + KVBM's Event Manager is responsible for emitting events. + Currently tested separately as a part of dynamo integration tests. + """ + pass + + +def _test_eagle_enabled_removes_last_block(): + """NOTE: KVBM does not support spec decoding at the moment. + Verify Eagle does NOT remove blocks when request + length is divisible by block size.""" + pass + + +def _test_eagle_with_partial_blocks(): + """NOTE: KVBM does not support spec decoding at the moment. + Test Eagle behavior with requests containing partial blocks.""" + pass + + +def _test_eagle_with_sliding_window(): + """NOTE: KVBM does not support spec decoding at the moment. + Test Eagle behavior with sliding window.""" + pass + + +def test_kvbm_wrong_blocks_provided(): + """ + Tests that providing wrong blocks to allocate_slots results in an error. + Specifically, we test that using blocks from one request for another request + with different tokens should fail. 
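+
+    Two failure modes are exercised below:
+      * before req0's blocks are registered, reuse trips an insufficient-capacity error;
+      * after registration, reuse trips a computed block sequence hash mismatch.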
+ """ + manager = new_kv_cache_manager() + + # Create two requests with different token patterns + req0 = make_request("0", [i for i in range(48)]) # 3 blocks of sequential tokens + req1 = make_request("1", [i * 2 for i in range(48)]) # 3 blocks of even tokens + + # Allocate and compute blocks for req0 + computed_blocks_req0, _ = manager.get_computed_blocks(req0) + _ = manager.allocate_slots(req0, 48, 0, computed_blocks_req0) + + # Simulate forward pass + req0.append_output_token_ids(100) # Add output token + req0.num_computed_tokens = 48 # Mark all input tokens as computed + _ = manager.allocate_slots(req0, num_new_tokens=1) # Allocate slot for output token + + # Try to use req0's blocks for req1 - this should fail + with pytest.raises(Exception) as exc_info: + manager.allocate_slots(req1, 48, 48, computed_blocks_req0) + assert ( + "slot error: Insufficient capacity: need 48 tokens but only 0 available in mutable blocks" + in str(exc_info.value) + ) + + # Get computed blocks after forward pass + computed_blocks_req0, num_computed_tokens = manager.get_computed_blocks(req0) + assert len(computed_blocks_req0.blocks) == 3 # Should have 3 complete blocks + assert num_computed_tokens == 48 # All input tokens should be computed + + # Try to use req0's blocks for req1 - this should fail + with pytest.raises(Exception) as exc_info: + manager.allocate_slots(req1, 48, 48, computed_blocks_req0) + assert "slot error: computed block sequence hash mismatch" in str(exc_info.value) + + # Clean up + manager.free_block_hashes(req0) + manager.free_block_hashes(req1) From 97793e2f5f9f2d8386aef2e68498d1bbdbc8fc2c Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 2 Jul 2025 12:58:54 -0700 Subject: [PATCH 2/3] Fix tests after rebase --- lib/bindings/python/rust/llm/block_manager.rs | 55 +-- .../python/tests/test_block_manager.py | 395 ------------------ lib/bindings/python/tests/test_kvbm.py | 38 +- .../tests/test_kvbm_vllm_integration.py | 66 ++- .../data/logical/distributed_leader_worker.rs | 79 ++-- lib/llm/src/block_manager/distributed.rs | 2 +- 6 files changed, 120 insertions(+), 515 deletions(-) delete mode 100644 lib/bindings/python/tests/test_block_manager.py diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index bc7db4b4d8..2d2142558c 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -51,10 +51,10 @@ pub struct BlockManager { #[pymethods] impl BlockManager { #[new] - #[pyo3(signature = (worker_id, leader, page_size, device_num_blocks))] + #[pyo3(signature = (worker_id, leader = None, page_size = 32, device_num_blocks = 16))] fn new( worker_id: u64, - leader: distributed::KvbmLeader, + leader: Option, page_size: usize, device_num_blocks: usize, ) -> PyResult { @@ -85,29 +85,34 @@ impl BlockManager { .map_err(to_pyerr)?, ); - let (leader, rt) = leader.dissolve(); - - if leader.num_host_blocks() > 0 { - tracing::info!("Using {} host blocks", leader.num_host_blocks()); - config = config.host_layout( - dynamo_llm::block_manager::KvManagerLayoutConfig::builder() - .num_blocks(leader.num_host_blocks()) - .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) - .build() - .map_err(to_pyerr)?, - ); - } - - if leader.num_disk_blocks() > 0 { - tracing::info!("Using {} disk blocks", leader.num_disk_blocks()); - config = config.disk_layout( - dynamo_llm::block_manager::KvManagerLayoutConfig::builder() - .num_blocks(leader.num_disk_blocks()) - 
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) - .build() - .map_err(to_pyerr)?, - ); - } + let (leader, rt) = if let Some(leader) = leader { + let (leader, rt) = leader.dissolve(); + if leader.num_host_blocks() > 0 { + tracing::info!("Using {} host blocks", leader.num_host_blocks()); + config = config.host_layout( + dynamo_llm::block_manager::KvManagerLayoutConfig::builder() + .num_blocks(leader.num_host_blocks()) + .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) + .build() + .map_err(to_pyerr)?, + ); + } + + if leader.num_disk_blocks() > 0 { + tracing::info!("Using {} disk blocks", leader.num_disk_blocks()); + config = config.disk_layout( + dynamo_llm::block_manager::KvManagerLayoutConfig::builder() + .num_blocks(leader.num_disk_blocks()) + .logical(Some(BlockParallelismStrategy::LeaderWorkerSharded)) + .build() + .map_err(to_pyerr)?, + ); + } + (Some(leader), rt) + } else { + tracing::info!("Leader not provided. Block transfer functionality will be disabled."); + (None, Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().map_err(to_pyerr)?)) + }; let config = config.build().map_err(to_pyerr)?; Ok(BlockManager { diff --git a/lib/bindings/python/tests/test_block_manager.py b/lib/bindings/python/tests/test_block_manager.py deleted file mode 100644 index 94c7b455db..0000000000 --- a/lib/bindings/python/tests/test_block_manager.py +++ /dev/null @@ -1,395 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import asyncio - -import pytest -import torch - -from dynamo.llm import BlockManager - -pytestmark = pytest.mark.pre_merge - - -WORKER_ID = 0 -NUM_LAYER = 5 -OUTER_DIM = 2 -PAGE_SIZE = 4 -INNER_DIM = 13 -DTYPE, TORCH_DTYPE = "FP32", torch.float32 -HOST_NUM_BLOCKS = 16 -DEVICE_NUM_BLOCKS = 16 -DEVICE_ID = 0 - - -def new_block_manager(): - return BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - HOST_NUM_BLOCKS, - DEVICE_NUM_BLOCKS, - DEVICE_ID, - ) - - -@pytest.fixture -def block_manager(): - return new_block_manager() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_block_manager_initialization(): - # Python should drop the BlockManager instance as soon as it goes out of scope, but - # it may not be garbage collected immediately, depending on the garbage collector. 
- BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM) - BlockManager(WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE) - BlockManager( - WORKER_ID, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM, DTYPE, HOST_NUM_BLOCKS - ) - BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - device_num_blocks=DEVICE_NUM_BLOCKS, - ) - BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - HOST_NUM_BLOCKS, - DEVICE_NUM_BLOCKS, - ) - BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - device_num_blocks=DEVICE_NUM_BLOCKS, - device_id=DEVICE_ID, - ) - BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - HOST_NUM_BLOCKS, - DEVICE_NUM_BLOCKS, - DEVICE_ID, - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_cpu_block_access(block_manager: BlockManager): - block_count = 2 - block_list = block_manager.allocate_host_blocks_blocking(block_count) - blocks = block_list.to_list() - assert len(blocks) == block_count - tensors = [torch.from_dlpack(b) for b in blocks] - for tensor in tensors: - assert tensor.get_device() == -1 # CPU - assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM) - assert tensor.dtype == TORCH_DTYPE - # print(tensors) - for tensor in tensors: - tensor[0][0][0][0][0] = 1.0 - tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 - # print(tensors) - blocks_ = block_list.to_list() - assert blocks is not blocks_ - assert len(blocks) == len(blocks_) - tensors_ = [torch.from_dlpack(b) for b in blocks_] - for tensor, tensor_ in zip(tensors, tensors_): - assert tensor is not tensor_ - assert tensor.shape == tensor_.shape - assert tensor.dtype == tensor_.dtype - assert torch.allclose(tensor, tensor_) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_gpu_block_access(block_manager: BlockManager): - block_count = 6 - block_list = block_manager.allocate_device_blocks_blocking(block_count) - blocks = block_list.to_list() - assert len(blocks) == block_count - tensors = [torch.from_dlpack(b) for b in blocks] - for tensor in tensors: - assert tensor.get_device() == DEVICE_ID # GPU - assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM) - assert tensor.dtype == TORCH_DTYPE - # print(tensors) - for tensor in tensors: - tensor[0][0][0][0][0] = 1.0 - tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 - # print(tensors) - blocks_ = block_list.to_list() - assert blocks is not blocks_ - assert len(blocks) == len(blocks_) - tensors_ = [torch.from_dlpack(b) for b in blocks_] - for tensor, tensor_ in zip(tensors, tensors_): - assert tensor is not tensor_ - assert tensor.shape == tensor_.shape - assert tensor.dtype == tensor_.dtype - assert torch.allclose(tensor, tensor_) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_block_list_iteration(block_manager: BlockManager): - block_count = 4 - block_list = await block_manager.allocate_host_blocks(block_count) - # Test __len__() - assert len(block_list) == block_count - # Test __getitem__() - for i in range(block_count): - block = block_list[i] - tensor = torch.from_dlpack(block) - tensor[0][0][0][0][0] = 1.0 + i - # Test __iter__() and __next__() - idx = 1.0 - for block in block_list: - tensor = torch.from_dlpack(block) - assert tensor[0][0][0][0][0] == idx - tensor[0][0][0][0][0] += 
0.5 - idx += 1.0 - assert idx == 1.0 + block_count - # Test __iter__() should reset current index - idx = 1.0 - for block in block_list: - tensor = torch.from_dlpack(block) - assert tensor[0][0][0][0][0] == idx + 0.5 - idx += 1.0 - assert idx == 1.0 + block_count - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_block_copy_g1_g2(block_manager: BlockManager): - # Allocate device (G1) and host (G2) block - host_block_list = await block_manager.allocate_host_blocks(1) - device_block_list = await block_manager.allocate_device_blocks(1) - # Populate host block with unique values - host_tensor = torch.from_dlpack(host_block_list[0]) - for i in range(NUM_LAYER): - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - host_tensor[0][i][j][k][w] = ( - i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - # Copy host block to device block after permuting - permute_dims = (0, 2, 4, 3, 1) - device_tensor_ = torch.from_dlpack(device_block_list[0]).permute(*permute_dims) - device_tensor_.copy_(host_tensor.permute(*permute_dims)) - # Assert device block is contiguous and updated in block manager - device_tensor = torch.from_dlpack(device_block_list[0]) - for i in range(NUM_LAYER): - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - assert ( - device_tensor[0][i][j][k][w] - == i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - # Set host block to zero and assert updated in block manager - host_tensor_ = torch.from_dlpack(host_block_list[0]).permute(*permute_dims) - host_tensor_.zero_() - assert torch.all(host_tensor == 0) - # Copy device block back to host block - host_tensor_.copy_(device_tensor_) - # Assert host block is updated in block manager - for i in range(NUM_LAYER): - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - assert ( - host_tensor[0][i][j][k][w] - == i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_cpu_layer_access(block_manager: BlockManager): - block_list = block_manager.allocate_host_blocks_blocking(1) - block = block_list[0] - layers = block.to_list() - assert len(layers) == NUM_LAYER - tensors = [torch.from_dlpack(bl) for bl in layers] - for tensor in tensors: - assert tensor.get_device() == -1 # CPU - assert tensor.shape == (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM) - assert tensor.dtype == TORCH_DTYPE - # print(tensors) - for tensor in tensors: - tensor[0][0][0][0][0] = 1.0 - tensor[0][0][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 - # print(tensors) - layers_ = block.to_list() - assert layers is not layers_ - assert len(layers) == len(layers_) - tensors_ = [torch.from_dlpack(bl) for bl in layers_] - for tensor, tensor_ in zip(tensors, tensors_): - assert tensor is not tensor_ - assert tensor.shape == tensor_.shape - assert tensor.dtype == tensor_.dtype - assert torch.allclose(tensor, tensor_) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_gpu_layer_access(block_manager: BlockManager): - block_list = block_manager.allocate_device_blocks_blocking(1) - block = block_list[0] - layers = block.to_list() - assert len(layers) == NUM_LAYER - tensors = [torch.from_dlpack(bl) for bl in layers] - for tensor in tensors: - assert tensor.get_device() == 
DEVICE_ID # GPU - assert tensor.shape == (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM) - assert tensor.dtype == TORCH_DTYPE - # print(tensors) - for tensor in tensors: - tensor[0][0][0][0][0] = 1.0 - tensor[0][0][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 - # print(tensors) - layers_ = block.to_list() - assert layers is not layers_ - assert len(layers) == len(layers_) - tensors_ = [torch.from_dlpack(bl) for bl in layers_] - for tensor, tensor_ in zip(tensors, tensors_): - assert tensor is not tensor_ - assert tensor.shape == tensor_.shape - assert tensor.dtype == tensor_.dtype - assert torch.allclose(tensor, tensor_) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_block_iteration(block_manager: BlockManager): - block = (await block_manager.allocate_host_blocks(1))[0] - # Test __len__() - assert len(block) == NUM_LAYER - # Test __getitem__() - for i in range(NUM_LAYER): - layer = block[i] - tensor = torch.from_dlpack(layer) - tensor[0][0][0][0][0] = 1.0 + i - # Test __iter__() and __next__() - idx = 1.0 - for layer in block: - tensor = torch.from_dlpack(layer) - assert tensor[0][0][0][0][0] == idx - tensor[0][0][0][0][0] += 0.5 - idx += 1.0 - assert idx == 1.0 + NUM_LAYER - # Test __iter__() should reset current index - idx = 1.0 - for layer in block: - tensor = torch.from_dlpack(layer) - assert tensor[0][0][0][0][0] == idx + 0.5 - idx += 1.0 - assert idx == 1.0 + NUM_LAYER - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_block_layer_copy_g1_g2(block_manager: BlockManager): - # Allocate device (G1) and host (G2) block - host_block = (await block_manager.allocate_host_blocks(1))[0] - device_block = (await block_manager.allocate_device_blocks(1))[0] - # Populate host block at layer level with unique values - host_layer_tensors = [torch.from_dlpack(bl) for bl in host_block] - for i in range(NUM_LAYER): - host_layer_tensor = host_layer_tensors[i] - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - host_layer_tensor[0][0][j][k][w] = ( - i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - # Copy host block to device block after permuting - permute_dims = (0, 2, 4, 3, 1) - host_block_tensor_ = torch.from_dlpack(host_block).permute(*permute_dims) - device_block_tensor_ = torch.from_dlpack(device_block).permute(*permute_dims) - device_block_tensor_.copy_(host_block_tensor_) - # Assert device block is contiguous and updated in block manager at layer level - device_layer_tensors = [torch.from_dlpack(bl) for bl in device_block] - for i in range(NUM_LAYER): - device_layer_tensor = device_layer_tensors[i] - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - assert ( - device_layer_tensor[0][0][j][k][w] - == i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - # Set host block to zero and assert updated in block manager - host_block_tensor = torch.from_dlpack(host_block) - host_block_tensor.zero_() - assert torch.all(host_block_tensor_ == 0) - # Copy device block back to host block - host_block_tensor_.copy_(device_block_tensor_) - # Assert host block is updated in block manager - for i in range(NUM_LAYER): - for j in range(OUTER_DIM): - for k in range(PAGE_SIZE): - for w in range(INNER_DIM): - assert ( - host_block_tensor[0][i][j][k][w] - == i * OUTER_DIM * PAGE_SIZE * INNER_DIM - + j * PAGE_SIZE * INNER_DIM - + k * INNER_DIM - + w - ) - - 
-async def main(): - await test_block_manager_initialization() - await test_cpu_block_access(new_block_manager()) - await test_gpu_block_access(new_block_manager()) - await test_block_list_iteration(new_block_manager()) - await test_block_copy_g1_g2(new_block_manager()) - await test_cpu_layer_access(new_block_manager()) - await test_gpu_layer_access(new_block_manager()) - await test_block_iteration(new_block_manager()) - await test_block_layer_copy_g1_g2(new_block_manager()) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/lib/bindings/python/tests/test_kvbm.py b/lib/bindings/python/tests/test_kvbm.py index b667a1bc30..c9972f07d4 100644 --- a/lib/bindings/python/tests/test_kvbm.py +++ b/lib/bindings/python/tests/test_kvbm.py @@ -12,22 +12,17 @@ import torch from vllm.v1.request import Request, SamplingParams -from dynamo.llm import BlockManager -from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager +try: + from dynamo.llm import BlockManager + from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager + KVBM_NOT_AVAILABLE = False +except ImportError: + KVBM_NOT_AVAILABLE = True pytestmark = pytest.mark.pre_merge - -WORKER_ID = 0 -NUM_LAYER = 5 -OUTER_DIM = 2 PAGE_SIZE = 4 -INNER_DIM = 13 -DTYPE, TORCH_DTYPE = "FP32", torch.float32 -HOST_NUM_BLOCKS = 16 DEVICE_NUM_BLOCKS = 16 -DEVICE_ID = 0 - def new_request(): return Request( @@ -55,15 +50,10 @@ def new_kv_cache_manager(): try: return KvbmCacheManager( BlockManager( - WORKER_ID, - NUM_LAYER, - OUTER_DIM, - PAGE_SIZE, - INNER_DIM, - DTYPE, - HOST_NUM_BLOCKS, - DEVICE_NUM_BLOCKS, - DEVICE_ID, + worker_id=0, + leader=None, + page_size=PAGE_SIZE, + device_num_blocks=DEVICE_NUM_BLOCKS, ) ) except Exception as e: @@ -72,13 +62,17 @@ def new_kv_cache_manager(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") -async def test_kvbm(block_manager: KvbmCacheManager): +@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available") +async def test_kvbm(): """ Tests the KVBM kv_cache_manager APIs. Args: block_manager: The KVBM cache manager. """ + + block_manager = new_kv_cache_manager() + request_1 = new_request() request_2 = new_request() request_3 = new_request() @@ -133,7 +127,7 @@ async def main(): """ Main function to run the test. """ - await test_kvbm(new_kv_cache_manager()) + await test_kvbm() if __name__ == "__main__": diff --git a/lib/bindings/python/tests/test_kvbm_vllm_integration.py b/lib/bindings/python/tests/test_kvbm_vllm_integration.py index 04ee8b9d7e..38dcf991f8 100644 --- a/lib/bindings/python/tests/test_kvbm_vllm_integration.py +++ b/lib/bindings/python/tests/test_kvbm_vllm_integration.py @@ -21,18 +21,16 @@ except ImportError: VLLM_NOT_AVAILABLE = True -from dynamo.llm import BlockManager -from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager - +try: + from dynamo.llm import BlockManager + from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager + KVBM_NOT_AVAILABLE = False +except: + KVBM_NOT_AVAILABLE = True def new_kv_cache_manager( - worker_id: int = 0, - num_layer: int = 1, - outer_dim: int = 1, - page_size: int = 16, - inner_dim: int = 1, - device_id: int = 0, num_blocks: int = 11, + page_size: int = 16 ): """ Creates a new KVBM cache manager. @@ -40,17 +38,13 @@ def new_kv_cache_manager( Returns: KvbmCacheManager: The KVBM cache manager. 
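+
+    Args:
+        num_blocks: number of device blocks to create; defaults to 11.
+        page_size: tokens per block; defaults to 16.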
""" + return KvbmCacheManager( BlockManager( - worker_id, - num_layer, - outer_dim, - page_size, - inner_dim, - "FP32", # dtype - num_blocks, # host_num_blocks - num_blocks, # device_num_blocks - device_id, + worker_id=0, + leader=None, + page_size=page_size, + device_num_blocks=num_blocks, ) ) @@ -95,7 +89,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_prefill(): """ Tests the KvbmCacheManager's prefill functionality. @@ -146,17 +140,13 @@ def test_prefill(): for block in computed_blocks.blocks: assert block._block_hash is not None - print(block) # Clean up + del computed_blocks + manager.free_block_hashes(req0) - manager.free_block_hashes(req1) - # TODO(oandreeva): - # Currently need a delay here, since there's an intermittency. - # Not all blocks become available for KVBM to assign to the new request. - # Heppens with `free` and `free_block_hashes` calls. - time.sleep(0.5) + manager.free_block_hashes(req1) # Cache miss and eviction. req3 = make_request("3", [24] * (16 * 11)) @@ -174,8 +164,6 @@ def test_prefill(): assert block._block_hash is None assert block.block_id == expected_block_id - time.sleep(0.5) - @pytest.mark.skip(reason="KVBM needs to support reset_prefix_cache") def test_prefill_plp(): @@ -294,7 +282,7 @@ def test_prefill_plp(): manager.free(req2) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_decode(): manager = new_kv_cache_manager() @@ -362,7 +350,7 @@ def test_decode(): manager.free_block_hashes(req0) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_evict(): manager = new_kv_cache_manager() used_blocks = set() @@ -428,10 +416,10 @@ def test_evict(): # assert manager.block_pool.free_block_queue.num_free_blocks == 7 -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_hash_block_correct_reuse(): """ - This tests when a previously cached block is reused as a new block, + This tests when a previously cached block is reused as a new block, its hash metadata should be correctly reset. """ block_size = 16 @@ -459,9 +447,8 @@ def test_hash_block_correct_reuse(): assert computed_blocks.blocks[0].block_id == 0 # Deallocate the block. + del computed_blocks manager.free(req) - # Note(oandreeva): need to fix this in the kvbm core to not depend on time.sleep() - time.sleep(2) # Allocate new blocks, last one is partial not full, make sure hash info on the # blocks are cleared. @@ -480,7 +467,7 @@ def test_hash_block_correct_reuse(): assert blocks.blocks[1].block_hash is None -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_computed_blocks_not_evicted(): """ Test that the computed blocks are not evicted when getting new blocks @@ -530,8 +517,7 @@ def test_computed_blocks_not_evicted(): # Free the blocks. 
manager.free(req0) manager.free(req1) - # Note(oandreeva): need to fix this in the kvbm core to not depend on time.sleep() - time.sleep(2) + del computed_blocks # Now if we have a cache hit on the block_id 0, we should evict the block_id 1 # cached block rather than the first one. @@ -578,7 +564,7 @@ def _test_mm_prefix_caching(): pass -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_cache_key_salting(): """ This tests that cache salts are applied during hashing and the cache @@ -649,7 +635,7 @@ def test_cache_key_salting(): """ -@pytest.mark.skipif(VLLM_NOT_AVAILABLE, reason="VLLM not available") +@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ This is a unit test that tests the correctness of the allocate_slots @@ -772,7 +758,7 @@ def _test_eagle_with_sliding_window(): Test Eagle behavior with sliding window.""" pass - +@pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available") def test_kvbm_wrong_blocks_provided(): """ Tests that providing wrong blocks to allocate_slots results in an error. diff --git a/lib/llm/src/block_manager/block/data/logical/distributed_leader_worker.rs b/lib/llm/src/block_manager/block/data/logical/distributed_leader_worker.rs index b4ba56b414..6f1e425b1f 100644 --- a/lib/llm/src/block_manager/block/data/logical/distributed_leader_worker.rs +++ b/lib/llm/src/block_manager/block/data/logical/distributed_leader_worker.rs @@ -13,7 +13,9 @@ type TransferRequest = (BlockTransferRequest, oneshot::Sender<()>); #[derive(Clone)] pub struct DistributedLeaderWorkerResources { - transfer_tx: mpsc::UnboundedSender, + /// Make this an option to make testing easier. + // TODO(jothomson): We should be using NullResources for this. 
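+    /// When `None`, any attempted block transfer panics with
+    /// "Block transfer functionality is disabled.".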
+ transfer_tx: Option>, } impl std::fmt::Debug for DistributedLeaderWorkerResources { @@ -23,19 +25,28 @@ impl std::fmt::Debug for DistributedLeaderWorkerResources { } impl DistributedLeaderWorkerResources { - pub fn new(leader: Arc, cancel_token: CancellationToken) -> anyhow::Result { - let (transfer_tx, transfer_rx) = mpsc::unbounded_channel(); - - CriticalTaskExecutionHandle::new( - move |cancel_token| async move { - Self::worker(leader, transfer_rx, cancel_token).await - }, - cancel_token, - "DistributedLeaderWorkerResources", - ) - .map_err(|e| anyhow::anyhow!("Failed to create DistributedLeaderWorkerResources: {}", e))?.detach(); - - Ok(Self { transfer_tx }) + pub fn new( + leader: Option>, + cancel_token: CancellationToken, + ) -> anyhow::Result { + if let Some(leader) = leader { + let (transfer_tx, transfer_rx) = mpsc::unbounded_channel(); + + CriticalTaskExecutionHandle::new( + move |cancel_token| async move { + Self::worker(leader, transfer_rx, cancel_token).await + }, + cancel_token, + "DistributedLeaderWorkerResources", + ) + .map_err(|e| anyhow::anyhow!("Failed to create DistributedLeaderWorkerResources: {}", e))?.detach(); + + Ok(Self { + transfer_tx: Some(transfer_tx), + }) + } else { + Ok(Self { transfer_tx: None }) + } } fn get_pool(data: &impl BlockDataExt) -> BlockTransferPool { @@ -90,25 +101,29 @@ impl LogicalResources for DistributedLeaderWorkerResources { RB: BlockDataProvider>, WB: WritableBlock + BlockDataProviderMut>, { - let source_pool = Self::get_pool(sources[0].block_data()); - let target_pool = Self::get_pool(targets[0].block_data()); - - let source_idxs = sources.iter().map(|source| source.block_data().block_id()); - let target_idxs = targets.iter().map(|target| target.block_data().block_id()); - - let request = BlockTransferRequest::new( - source_pool, - target_pool, - source_idxs.zip(target_idxs).collect(), - ); - - let (tx, rx) = oneshot::channel(); - self.transfer_tx.send((request, tx)).unwrap(); - - if notify { - Ok(Some(rx)) + if let Some(transfer_tx) = &self.transfer_tx { + let source_pool = Self::get_pool(sources[0].block_data()); + let target_pool = Self::get_pool(targets[0].block_data()); + + let source_idxs = sources.iter().map(|source| source.block_data().block_id()); + let target_idxs = targets.iter().map(|target| target.block_data().block_id()); + + let request = BlockTransferRequest::new( + source_pool, + target_pool, + source_idxs.zip(target_idxs).collect(), + ); + + let (tx, rx) = oneshot::channel(); + transfer_tx.send((request, tx)).unwrap(); + + if notify { + Ok(Some(rx)) + } else { + Ok(None) + } } else { - Ok(None) + panic!("Block transfer functionality is disabled."); } } } diff --git a/lib/llm/src/block_manager/distributed.rs b/lib/llm/src/block_manager/distributed.rs index 0f9a66fa57..dad28c1a1d 100644 --- a/lib/llm/src/block_manager/distributed.rs +++ b/lib/llm/src/block_manager/distributed.rs @@ -226,7 +226,7 @@ mod tests { .build()?; let resources = - DistributedLeaderWorkerResources::new(Arc::new(leader), cancel_token.child_token())?; + DistributedLeaderWorkerResources::new(Some(Arc::new(leader)), cancel_token.child_token())?; let block_manager = KvBlockManager::< Logical, From 7fe54de25e09054bdff55b678350d31330dfcf80 Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 2 Jul 2025 13:03:59 -0700 Subject: [PATCH 3/3] precommit + fmt --- lib/bindings/python/rust/llm/block_manager.rs | 12 ++++- lib/bindings/python/tests/test_kvbm.py | 2 + .../tests/test_kvbm_vllm_integration.py | 48 +++++++++++++------ 3 files changed, 46 
insertions(+), 16 deletions(-) diff --git a/lib/bindings/python/rust/llm/block_manager.rs b/lib/bindings/python/rust/llm/block_manager.rs index 2d2142558c..56ac9840ba 100644 --- a/lib/bindings/python/rust/llm/block_manager.rs +++ b/lib/bindings/python/rust/llm/block_manager.rs @@ -97,7 +97,7 @@ impl BlockManager { .map_err(to_pyerr)?, ); } - + if leader.num_disk_blocks() > 0 { tracing::info!("Using {} disk blocks", leader.num_disk_blocks()); config = config.disk_layout( @@ -111,7 +111,15 @@ impl BlockManager { (Some(leader), rt) } else { tracing::info!("Leader not provided. Block transfer functionality will be disabled."); - (None, Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().map_err(to_pyerr)?)) + ( + None, + Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(to_pyerr)?, + ), + ) }; let config = config.build().map_err(to_pyerr)?; diff --git a/lib/bindings/python/tests/test_kvbm.py b/lib/bindings/python/tests/test_kvbm.py index c9972f07d4..d2e1507034 100644 --- a/lib/bindings/python/tests/test_kvbm.py +++ b/lib/bindings/python/tests/test_kvbm.py @@ -15,6 +15,7 @@ try: from dynamo.llm import BlockManager from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager + KVBM_NOT_AVAILABLE = False except ImportError: KVBM_NOT_AVAILABLE = True @@ -24,6 +25,7 @@ PAGE_SIZE = 4 DEVICE_NUM_BLOCKS = 16 + def new_request(): return Request( request_id=str(uuid.uuid4()), diff --git a/lib/bindings/python/tests/test_kvbm_vllm_integration.py b/lib/bindings/python/tests/test_kvbm_vllm_integration.py index 38dcf991f8..672baad113 100644 --- a/lib/bindings/python/tests/test_kvbm_vllm_integration.py +++ b/lib/bindings/python/tests/test_kvbm_vllm_integration.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import time from typing import Optional import pytest @@ -24,14 +23,13 @@ try: from dynamo.llm import BlockManager from dynamo.llm.vllm_integration.kv_cache_manager import KvbmCacheManager + KVBM_NOT_AVAILABLE = False -except: +except ImportError: KVBM_NOT_AVAILABLE = True -def new_kv_cache_manager( - num_blocks: int = 11, - page_size: int = 16 -): + +def new_kv_cache_manager(num_blocks: int = 11, page_size: int = 16): """ Creates a new KVBM cache manager. @@ -89,7 +87,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_prefill(): """ Tests the KvbmCacheManager's prefill functionality. 
@@ -282,7 +283,10 @@ def test_prefill_plp(): manager.free(req2) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_decode(): manager = new_kv_cache_manager() @@ -350,7 +354,10 @@ def test_decode(): manager.free_block_hashes(req0) -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_evict(): manager = new_kv_cache_manager() used_blocks = set() @@ -416,10 +423,13 @@ def test_evict(): # assert manager.block_pool.free_block_queue.num_free_blocks == 7 -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_hash_block_correct_reuse(): """ - This tests when a previously cached block is reused as a new block, + This tests when a previously cached block is reused as a new block, its hash metadata should be correctly reset. """ block_size = 16 @@ -467,7 +477,10 @@ def test_hash_block_correct_reuse(): assert blocks.blocks[1].block_hash is None -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_computed_blocks_not_evicted(): """ Test that the computed blocks are not evicted when getting new blocks @@ -564,7 +577,10 @@ def _test_mm_prefix_caching(): pass -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_cache_key_salting(): """ This tests that cache salts are applied during hashing and the cache @@ -635,7 +651,10 @@ def test_cache_key_salting(): """ -@pytest.mark.skipif(VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, reason="VLLM not available or KVBM not available") +@pytest.mark.skipif( + VLLM_NOT_AVAILABLE or KVBM_NOT_AVAILABLE, + reason="VLLM not available or KVBM not available", +) def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ This is a unit test that tests the correctness of the allocate_slots @@ -758,6 +777,7 @@ def _test_eagle_with_sliding_window(): Test Eagle behavior with sliding window.""" pass + @pytest.mark.skipif(KVBM_NOT_AVAILABLE, reason="KVBM not available") def test_kvbm_wrong_blocks_provided(): """