diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index ab7aa02823ab..347f98c772ff 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -900,3 +900,19 @@ def test_get_kv_cache_config(): with pytest.raises(NotImplementedError): get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + + # Test num_gpu_blocks_override + vllm_config.cache_config.num_gpu_blocks_override = 16 + kv_cache_config_override_blocks = get_kv_cache_config( + vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_override_blocks == KVCacheConfig( + num_blocks=16, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 16, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) + ]) \ No newline at end of file diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6d4bcfe64a35..9489bcf433fd 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -660,6 +660,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int, logger.info( "Overriding num_gpu_blocks=%d with " "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + num_blocks = num_gpu_blocks_override return num_blocks