diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index cce2fb2c4814..6129752bcdd6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" +from typing import Optional + import pytest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -15,7 +17,8 @@ def make_request(request_id, prompt_token_ids, mm_positions=None, - mm_hashes=None): + mm_hashes=None, + prompt_logprobs: Optional[int] = None): if mm_positions is None: multi_modal_inputs = None else: @@ -28,7 +31,8 @@ def make_request(request_id, multi_modal_inputs=multi_modal_inputs, multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), + sampling_params=SamplingParams(max_tokens=17, + prompt_logprobs=prompt_logprobs), eos_token_id=100, arrival_time=0, lora_request=None, @@ -144,6 +148,110 @@ def test_prefill(): assert manager.block_pool.free_block_queue.free_list_tail is None +def test_prefill_plp(): + '''Test prefill with APC and some prompt logprobs (plp) requests. + + 1. Schedule plp request and validate APC block allocation + 2. Schedule non-plp request and validate blocks + 3. Schedule plp request; no hit should occur; validate blocks + ''' + manager = KVCacheManager( + block_size=16, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=16, + ) + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(16)] + + # Request #0 is a prompt logprobs request + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids, prompt_logprobs=5) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert not computed_blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots(req0, 55, computed_blocks) + assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + req0_block_hashes = [b.block_hash for b in blocks] + + # Check full block metadata + parent_block_hash = None + for block_id in (0, 1, 2): + block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) + block_hash = hash_block_tokens(parent_block_hash, block_tokens) + assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + parent_block_hash = block_hash.hash_value + + # Check partial/preallocated block metadata + for block_id in (3, 4): + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + # Request #1 is a non-prompt-logprobs request: + # Cache hit in the common prefix when the original block is still in use. 
+ # Incomplete 1 block (5 tokens) + unique_token_ids = [3] * 5 + req1 = make_request("1", common_token_ids + unique_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert num_computed_tokens == 3 * 16 + num_new_tokens = 53 - 3 * 16 + blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) + assert [b.block_id for b in blocks] == [5, 6] + for block in computed_blocks: + assert block.ref_cnt == 2 + + # At this point, we should have 3 free blocks left. + assert manager.block_pool.free_block_queue.num_free_blocks == 3 + + manager.free(req0) + manager.free(req1) + + # All blocks should be available. + assert manager.block_pool.free_block_queue.num_free_blocks == 10 + # The order should be + # [unallocated (7, 8, 9)] + # [unique_req0 (4, 3)] + # [unique_req1 (6, 5)] + # [common (2, 1, 0)] + assert [ + b.block_id + for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] + + # Request #2 is a prompt-logprobs request: + # NO cache hit in the common prefix; duplicates request #0 cached blocks + unique_token_ids = [3] * 6 + req2 = make_request("2", + common_token_ids + unique_token_ids, + prompt_logprobs=5) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert not computed_blocks + assert num_computed_tokens == 0 + blocks = manager.allocate_slots(req2, 55, computed_blocks) + block_ids = [b.block_id for b in blocks] + # Duplicate cached blocks have different ids but same hashes vs request #0 + assert [b.block_hash for b in blocks] == req0_block_hashes + assert block_ids != [0, 1, 2, 3, 4] + + # Request #2 block hashes are valid since request #0 hashes are. + # Check block reference counts. + for block_id in block_ids: + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + manager.free(req2) + + def test_decode(): manager = KVCacheManager( block_size=16, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 738ab2ef03de..9413373390fe 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional +import pytest + from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -16,7 +18,21 @@ def create_scheduler( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, ) -> Scheduler: + '''Create scheduler under test. 
+ + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, @@ -31,11 +47,16 @@ def create_scheduler( dtype="float16", seed=42, ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) cache_config = CacheConfig( block_size=16, gpu_memory_utilization=0.9, swap_space=0, cache_dtype="auto", + **kwargs_cache, ) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -54,16 +75,16 @@ def create_scheduler( ) -def create_requests( - num_requests: int, - num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, -): +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, - stop_token_ids=stop_token_ids) + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) requests = [] for i in range(num_requests): if mm_positions is not None: @@ -122,9 +143,18 @@ def test_get_num_unfinished_requests(): assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 -def test_schedule(): - scheduler = create_scheduler() - requests = create_requests(num_requests=10) +@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ + (None, None), + (True, 5), +]) +def test_schedule(enable_prefix_caching: Optional[bool], + prompt_logprobs: Optional[int]): + '''Test scheduling. + Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs + ''' + scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) + requests = create_requests(num_requests=10, + prompt_logprobs=prompt_logprobs) for request in requests: scheduler.add_request(request) @@ -427,14 +457,21 @@ def test_stop_via_update_from_output(): assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] -def test_schedule_concurrent_batches(): +@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ + (None, None), + (True, 5), +]) +def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], + prompt_logprobs: Optional[int]): scheduler = create_scheduler( max_num_batched_tokens=1024, max_num_seqs=2, + enable_prefix_caching=enable_prefix_caching, ) requests = create_requests( num_requests=2, num_tokens=512, + prompt_logprobs=prompt_logprobs, ) # Schedule the first request. 
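
Sketch (not part of the patch): the parametrized scheduler tests above now exercise the combination the engine previously rejected, prefix caching enabled together with prompt logprobs. The snippet below is a minimal, standalone restatement of that parameter matrix, assuming only the `SamplingParams` and `AsyncEngineArgs` constructors already used elsewhere in this diff; the variable names and the loop are illustrative, not vLLM API.

from typing import Optional

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs

# The (enable_prefix_caching, prompt_logprobs) pairs covered by the
# parametrized tests above; None means "keep the engine default".
cases: list[tuple[Optional[bool], Optional[int]]] = [
    (None, None),  # default APC config, no prompt logprobs
    (True, 5),     # APC forced on, prompt logprobs requested
]

for enable_prefix_caching, prompt_logprobs in cases:
    # Engine args accept APC explicitly (or fall back to the default)...
    apc_kwargs = ({} if enable_prefix_caching is None else
                  {"enable_prefix_caching": enable_prefix_caching})
    engine_args = AsyncEngineArgs(model="facebook/opt-125m", **apc_kwargs)
    # ...and SamplingParams accepts prompt_logprobs independently of APC;
    # the interaction between the two is resolved inside the KV cache
    # manager (see the vllm/v1/core/kv_cache_manager.py hunk below).
    params = SamplingParams(max_tokens=16, prompt_logprobs=prompt_logprobs)
    assert params.prompt_logprobs == prompt_logprobs
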
diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index e7b91aeb0fbd..0de0026eb284 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -6,7 +6,6 @@ import pytest -from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import SamplingParams from vllm.assets.image import ImageAsset from vllm.engine.arg_utils import AsyncEngineArgs @@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM, return count, request_id -@pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -@pytest.mark.asyncio -async def test_async_llm_refuses_prompt_logprobs_with_apc( - monkeypatch, output_kind: RequestOutputKind): - """Test passes if AsyncLLM raises an exception when it is configured - for automatic prefix caching and it receives a request with - prompt_logprobs enabled, which is incompatible.""" - # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a - # better way to test V1 so that in the future when we switch, we don't - # have to change all the tests. - monkeypatch.setenv("VLLM_USE_V1", "1") - # Create AsyncLLM engine with APC - apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m", - enable_prefix_caching=True, - gpu_memory_utilization=0.8, - disable_log_requests=True) - engine = AsyncLLM.from_engine_args(apc_engine_args) - try: - with pytest.raises(ValueError) as excinfo: - # Issue a request with prompt logprobs enabled, which should fail - await asyncio.create_task( - generate(engine, - "request-0", - TEXT_PROMPT, - output_kind, - 10, - prompt_logprobs=5)) - # Validate exception string is correct - assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG - finally: - # Shut down engine - engine.shutdown() - - @pytest.mark.parametrize( "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) @pytest.mark.parametrize("engine_args_and_prompt", diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 43b16d3e5a29..5446653cc3a2 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -5,7 +5,6 @@ import pytest -from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG from vllm import LLM, SamplingParams MODEL = "facebook/opt-125m" @@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: raise AssertionError( f"{len(completion_counts)} unique completions; expected" f" {n}. 
Repeats: {repeats}") - - -def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc): - """Test passes if LLMEngine raises an exception when it is configured - for automatic prefix caching and it receives a request with - prompt_logprobs enabled, which is incompatible.""" - model: LLM = vllm_model_apc.model - with pytest.raises(ValueError) as excinfo: - model.generate( - "Hello, my name is", - SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)) - - # Validate exception string is correct - assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 02baa4801a47..f0e344cfa6fc 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -30,9 +30,6 @@ STOP_STRINGS = ["I love working on", "company by far", "brother in"] PROMPT_LEN = 5 -PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet " - "supported on VLLM V1.") - random.seed(42) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index d564a8c2e7a7..9715573e3f14 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -1,24 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +from collections.abc import Generator import pytest import torch from tests.kernels.utils import override_backend_env_variable from tests.v1.sample.utils import ( + BatchLogprobsComposition, BatchLogprobsSpecType, assert_incr_detok_str_matches_non_incr_detok_str, compute_correct_cumulative_logprob, get_test_batch) from vllm import SamplingParams -from ...conftest import VllmRunner +from ...conftest import HfRunner, VllmRunner MODEL = "meta-llama/Llama-3.2-1B-Instruct" DTYPE = "half" +NONE = BatchLogprobsComposition.NONE +SAMPLE = BatchLogprobsComposition.SAMPLE +PROMPT = BatchLogprobsComposition.PROMPT +SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT -@pytest.fixture(scope="module") -def vllm_model(vllm_runner): + +@pytest.fixture( + scope="module", + # Parameterize APC + params=[False, True]) +def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: with vllm_runner( MODEL, dtype=DTYPE, @@ -31,22 +41,22 @@ def vllm_model(vllm_runner): enforce_eager=True, #TODO: enable this once we support it for # prompt logprobs. - enable_prefix_caching=False, + enable_prefix_caching=request.param, gpu_memory_utilization=0.5, ) as vllm_model: yield vllm_model @pytest.fixture(scope="module") -def hf_model(hf_runner): +def hf_model(hf_runner) -> Generator[HfRunner, None, None]: with hf_runner(MODEL, dtype=DTYPE) as hf_model: yield hf_model def _repeat_logprob_config( test_prompts, - logprob_prompt_logprob_list: list[tuple], -) -> list[tuple]: + logprob_prompt_logprob_list: BatchLogprobsSpecType, +) -> BatchLogprobsSpecType: """Ensure each test prompt has a logprob config. A logprob config specifies the optional (i.e. 
@@ -91,42 +101,17 @@ def _repeat_logprob_config( return logprob_prompt_logprob_list -def _test_case_get_logprobs_and_prompt_logprobs( - hf_model, - vllm_model, - batch_logprobs_composition: str, +def _run_and_validate( + vllm_model: VllmRunner, + test_prompts: list[str], + vllm_sampling_params: SamplingParams, + hf_logprobs: list[list[torch.Tensor]], + hf_outputs: list[tuple[list[int], str]], + logprob_prompt_logprob_list: BatchLogprobsSpecType, temperature: float, - example_prompts, + max_tokens: int, + do_apc: bool, ) -> None: - test_prompts = example_prompts - - max_tokens = 5 - hf_outputs = hf_model.generate_greedy( - test_prompts, - max_tokens=max_tokens, - ) - hf_logprobs = hf_model.generate_greedy_logprobs( - test_prompts, - max_tokens=max_tokens, - ) - - # Batch has mixed sample params - # (different logprobs/prompt logprobs combos) - logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) - - # Ensure that each test prompt has a logprob config for testing - logprob_prompt_logprob_list = _repeat_logprob_config( - test_prompts, logprob_prompt_logprob_list) - # Generate SamplingParams - vllm_sampling_params = [ - SamplingParams(max_tokens=max_tokens, - logprobs=num_lp, - prompt_logprobs=num_plp, - temperature=temperature, - seed=1984) - for num_lp, num_plp in logprob_prompt_logprob_list - ] - vllm_results = vllm_model.model.generate( test_prompts, sampling_params=vllm_sampling_params) @@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs( assert vllm_result.prompt_logprobs is None -#@pytest.mark.skip_global_cleanup @pytest.mark.parametrize("batch_logprobs_composition", - ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"]) + [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]) @pytest.mark.parametrize("temperature", [0.0, 2.0]) def test_get_logprobs_and_prompt_logprobs( hf_model, vllm_model, - batch_logprobs_composition: str, + batch_logprobs_composition: BatchLogprobsComposition, temperature: float, example_prompts, ) -> None: @@ -292,25 +276,70 @@ def test_get_logprobs_and_prompt_logprobs( batch_logprobs_composition controls the logprobs configurations for requests in the batch under test. + APC tests run two test iterations so that cache hits occur. + + To save time, only test one APC-enabled scenario + (sample & prompt logprobs enabled, temperature>0.0). + Args: - hf_model - vllm_model + hf_model: HuggingFace reference model fixture + vllm_model: vLLM model fixture batch_logprobs_composition: logprobs configuration for test batch - example_prompts - monkeypatch + temperature: "temperature" sampling parameter + example_prompts: example prompt fixture """ - _test_case_get_logprobs_and_prompt_logprobs( - hf_model=hf_model, - vllm_model=vllm_model, - batch_logprobs_composition=batch_logprobs_composition, - temperature=temperature, - example_prompts=example_prompts) + do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching + if do_apc and (temperature < 2.0 + or batch_logprobs_composition != SAMPLE_PROMPT): + # Skip some test-cases to save time. 
+ pytest.skip() + test_prompts = example_prompts + + max_tokens = 5 + hf_outputs = hf_model.generate_greedy( + test_prompts, + max_tokens=max_tokens, + ) + hf_logprobs = hf_model.generate_greedy_logprobs( + test_prompts, + max_tokens=max_tokens, + ) + + # Batch has mixed sample params + # (different logprobs/prompt logprobs combos) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) + + # Ensure that each test prompt has a logprob config for testing + logprob_prompt_logprob_list = _repeat_logprob_config( + test_prompts, logprob_prompt_logprob_list) + # Generate SamplingParams + vllm_sampling_params = [ + SamplingParams(max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984) + for num_lp, num_plp in logprob_prompt_logprob_list + ] + for _ in range(2 if do_apc else 1): + _run_and_validate( + vllm_model=vllm_model, + test_prompts=test_prompts, + vllm_sampling_params=vllm_sampling_params, + hf_logprobs=hf_logprobs, + hf_outputs=hf_outputs, + logprob_prompt_logprob_list=logprob_prompt_logprob_list, + temperature=temperature, + max_tokens=max_tokens, + do_apc=do_apc) def test_max_logprobs(monkeypatch): """vLLM v1 engine should fail a request with `logprobs > max_logprobs` Should also fail for `prompt_logprobs > max_logprobs` + + APC should not matter as this test checks basic request validation. Args: monkeypatch @@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch): runner.generate(["Hello world"], sampling_params=bad_sampling_params) -def test_none_logprobs(vllm_model, example_prompts, monkeypatch): +def test_none_logprobs(vllm_model, example_prompts): """Engine should return `logprobs` and `prompt_logprobs` as `None` Args: vllm_model: vLLM model fixture example_prompts: list of example prompts (test fixture) - monkeypatch: supports editing env vars and rolling back changes - after the test """ max_tokens = 5 @@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts, monkeypatch): assert results_logprobs_none[i].prompt_logprobs is None -def test_zero_logprobs(vllm_model, example_prompts, monkeypatch): +def test_zero_logprobs(vllm_model, example_prompts): """Engine should return sampled token and prompt token logprobs Args: vllm_model: vLLM model fixture example_prompts: list of example prompts (test fixture) - monkeypatch: supports editing env vars and rolling back changes - after the test """ max_tokens = 5 diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index c69d0d49c46f..f540895bbf14 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -1,27 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 import re +from enum import Enum +from typing import Optional from vllm import CompletionOutput -def get_test_batch(batch_logprobs_composition: str) -> list[tuple]: +class BatchLogprobsComposition(Enum): + """Types of logprobs configs to include in test batch""" + NONE = 0 + SAMPLE = 1 + PROMPT = 2 + SAMPLE_PROMPT = 3 + + +BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]] + + +def get_test_batch( + batch_logprobs_composition: BatchLogprobsComposition +) -> BatchLogprobsSpecType: """Generate logprobs configs for a batch of requests A given request's logprobs configuration is (1) num_sample_logprobs and (2) num_prompt_logprobs. The batch logprobs configuration is the list of request logprobs configs. 
- batch_logprobs_composition == "NONE" yields a batch with no sample or prompt + batch_logprobs_composition == NONE yields a batch with no sample or prompt logprobs - batch_logprobs_composition == "SAMPLE" yields a batch with some requests + batch_logprobs_composition == SAMPLE yields a batch with some requests configured for sample logprobs only, and others configured for no logprobs - batch_logprobs_composition == "PROMPT" yields a batch with some requests + batch_logprobs_composition == PROMPT yields a batch with some requests configured for prompt logprobs only, and others configured for no logprobs - batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some + batch_logprobs_composition == SAMPLE_PROMPT yields a batch with some requests configured for sample logprobs and prompt logprobs, some configured for only sample logprobs or only prompt logprobs, and some configured for no logprobs @@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]: list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs]) tuples """ - if batch_logprobs_composition == "NONE": + if batch_logprobs_composition == BatchLogprobsComposition.NONE: # No requests with sample or prompt logprobs return [(None, None)] - elif batch_logprobs_composition == "SAMPLE": + elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE: # Requests requiring sample logprobs or no logprobs return [ (None, None), @@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]: (5, None), (3, None), ] - elif batch_logprobs_composition == "PROMPT": + elif batch_logprobs_composition == BatchLogprobsComposition.PROMPT: # Requests requiring prompt logprobs or no logprobs return [ (None, None), @@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]: (None, 6), (None, 5), ] - elif batch_logprobs_composition == "SAMPLE_PROMPT": + elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE_PROMPT: # Requests requiring either no logprobs, just # sample logprobs, just prompt logprobs, or # both sample and prompt logprobs diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6c6be01a2ff7..5cfe2b96865a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -105,8 +105,6 @@ def get_computed_blocks( # Prefix caching is disabled. return [], 0 - computed_blocks = [] - # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] @@ -114,24 +112,31 @@ def get_computed_blocks( block_hashes = hash_request_tokens(self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break - self.prefix_cache_stats.requests += 1 - self.prefix_cache_stats.queries += len(block_hashes) - self.prefix_cache_stats.hits += len(computed_blocks) - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. 
- num_computed_tokens = len(computed_blocks) * self.block_size - return computed_blocks, num_computed_tokens + if request.sampling_params.prompt_logprobs is None: + # Check for cache hits + computed_blocks = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash + # is not in the cached_block_hash_to_id, the following + # block hashes are not computed yet for sure. + if cached_block := self.block_pool.get_cached_block( + block_hash): + computed_blocks.append(cached_block) + else: + break + + self.prefix_cache_stats.queries += len(block_hashes) + self.prefix_cache_stats.hits += len(computed_blocks) + + # NOTE(woosuk): Since incomplete blocks are not eligible for + # sharing, `num_computed_tokens` is always a multiple of + # `block_size`. + num_computed_tokens = len(computed_blocks) * self.block_size + return computed_blocks, num_computed_tokens + else: + # Skip cache hits for prompt logprobs + return [], 0 def allocate_slots( self, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index b3226a280d8b..247fb046e81a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -72,12 +72,6 @@ def _validate_logprobs( f"Requested prompt logprobs of {params.prompt_logprobs}, " f"which is greater than max allowed: {max_logprobs}") - # TODO(andy): enable this in follow up by recomputing. - if (params.prompt_logprobs is not None - and self.cache_config.enable_prefix_caching): - raise ValueError("Prefix caching with prompt logprobs not yet " - "supported on VLLM V1.") - def _validate_sampling_params( self, params: SamplingParams,
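
Net effect of the patch: the processor no longer refuses prompt logprobs when prefix caching is on; instead, `KVCacheManager.get_computed_blocks` treats any request with `prompt_logprobs` set as a full cache miss, so its whole prompt is recomputed (allowing prompt logprobs to be produced for every prompt token), while all other requests keep reusing cached full blocks. The standalone sketch below restates that rule outside of vLLM under the assumption of a toy hash-to-block-id mapping; the function name and data structures are illustrative, not the engine's internals.

from typing import Optional


def computed_prefix(block_hashes: list[int],
                    cached_blocks: dict[int, int],
                    block_size: int,
                    prompt_logprobs: Optional[int]) -> tuple[list[int], int]:
    """Return (computed block ids, num computed tokens) for one request."""
    if prompt_logprobs is not None:
        # Prompt-logprobs request: skip cache hits entirely and behave as a
        # full miss, mirroring the `return [], 0` branch in the hunk above.
        return [], 0
    computed: list[int] = []
    for block_hash in block_hashes:
        # block_hashes is a chain: the first miss ends the reusable prefix.
        if block_hash in cached_blocks:
            computed.append(cached_blocks[block_hash])
        else:
            break
    # Only full blocks are shareable, so the computed-token count is a
    # multiple of the block size.
    return computed, len(computed) * block_size


# Example mirroring test_prefill_plp: three cached full blocks of 16 tokens.
cached = {0xA1: 0, 0xA2: 1, 0xA3: 2}
assert computed_prefix([0xA1, 0xA2, 0xA3], cached, 16, None) == ([0, 1, 2], 48)
assert computed_prefix([0xA1, 0xA2, 0xA3], cached, 16, 5) == ([], 0)

With the check in vllm/v1/engine/processor.py removed, an APC-enabled V1 engine now accepts requests that set prompt_logprobs rather than raising "Prefix caching with prompt logprobs not yet supported on VLLM V1", which is why the corresponding refusal tests in tests/v1/engine/test_async_llm.py and tests/v1/engine/test_llm_engine.py are deleted above.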