diff --git a/docs/features/README.md b/docs/features/README.md index 05ce0b57a9fc..537b5862e39f 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -52,7 +52,7 @@ th:not(:first-child) { | [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | | best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | | | beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | | -| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | \* Chunked prefill and prefix caching are only applicable to last-token pooling. ^ LoRA is only applicable to the language backbone of multimodal models. @@ -77,4 +77,4 @@ th:not(:first-child) { | multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ | | best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](gh-issue:25097) | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:25097) | diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 09f43a793db2..f7544da6a3a7 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import Callable, Optional +from typing import Any, Callable, Optional import pytest import torch @@ -29,6 +29,7 @@ UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request +from vllm.v1.utils import tensor_data # yapf: enable @@ -47,12 +48,13 @@ def _auto_init_hash_fn(request): def make_request( request_id: str, - prompt_token_ids: list[int], + prompt_token_ids: Optional[list[int]], block_size: int = 3, hash_fn: Callable = hash, mm_positions: Optional[list[PlaceholderRange]] = None, mm_hashes: Optional[list[str]] = None, cache_salt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, ): mm_features = [] if mm_positions is not None: @@ -72,6 +74,7 @@ def make_request( pooling_params=None, eos_token_id=100, lora_request=None, + prompt_embeds=prompt_embeds, cache_salt=cache_salt, block_hasher=get_request_block_hasher(block_size, hash_fn)) @@ -441,7 +444,8 @@ def test_hash_block_tokens(hash_fn): block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) + expected = hash_fn( + (parent_block_hash, curr_block_token_ids, None, extra_keys)) assert block_hash == expected @@ -462,9 +466,9 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) + (kv_cache_utils.NONE_HASH, (0, 1, 2), None, ("hash1", ))) assert block_hashes[1] == hash_fn( - (block_hashes[0], (3, 4, 5), ("hash2", ))) + (block_hashes[0], (3, 4, 5), None, ("hash2", ))) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -510,8 +514,8 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): assert len(block_hashes) == 2 assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) - assert 
block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) + (kv_cache_utils.NONE_HASH, (0, 1, 2), None, None)) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None, None)) def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats: @@ -1452,3 +1456,153 @@ def test_merge_mla_spec(): ] with pytest.raises(AssertionError): kv_cache_specs[0].merge(kv_cache_specs) + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_hash_block_tokens_with_prompt_embeds(hash_fn: Callable[[Any], bytes]): + parent_block_hash = BlockHash(b"123") + curr_block_token_ids = (1, 2, 3) + extra_keys = ("key1", "key2") + prompt_embeds = torch.randn((2, 3)) + + block_hash_with_embeds = hash_block_tokens(hash_fn, parent_block_hash, + curr_block_token_ids, + extra_keys, prompt_embeds) + + prompt_embeds_bytes = tensor_data(prompt_embeds).tobytes() + expected = hash_fn((parent_block_hash, curr_block_token_ids, + prompt_embeds_bytes, extra_keys)) + assert block_hash_with_embeds == expected + + block_hash_without_embeds = hash_block_tokens(hash_fn, parent_block_hash, + curr_block_token_ids, + extra_keys, None) + expected_without = hash_fn( + (parent_block_hash, curr_block_token_ids, None, extra_keys)) + assert block_hash_without_embeds == expected_without + assert block_hash_with_embeds != block_hash_without_embeds + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_hash_different_prompt_embeds(hash_fn: Callable[[Any], bytes]): + parent_block_hash = BlockHash(b"123") + curr_block_token_ids = (1, 2, 3) + prompt_embeds1 = torch.randn((2, 3)) + prompt_embeds2 = torch.randn((2, 3)) + + hash1 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, + None, prompt_embeds1) + hash2 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, + None, prompt_embeds2) + + assert hash1 != hash2 + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], + bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]), + block1_embeds_bytes, None)) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data( + prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]), + block2_embeds_bytes, None)) + assert block_hashes[1] == expected_hash2 + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], + bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], + mm_hashes=["hash1", "hash2"], + prompt_embeds=prompt_embeds, + ) + 
+ block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]), + block1_embeds_bytes, ("hash1", ))) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data( + prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]), + block2_embeds_bytes, ("hash2", ))) + assert block_hashes[1] == expected_hash2 + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_with_prompt_embeds_no_mm(hash_fn: Callable[[Any], bytes]): + """Test request with prompt embeddings but no multimodal inputs""" + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + # Verify hashes include prompt embeddings but no mm keys + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]), + block1_embeds_bytes, None)) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data( + prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]), + block2_embeds_bytes, None)) + assert block_hashes[1] == expected_hash2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 93ad4d8080e6..9a27d73c875c 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -907,14 +907,14 @@ def test_mm_prefix_caching(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), None, ("aaa", ))) assert block_hashes[1] == sha256( (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), - ("aaa", "bbb"))) + None, ("aaa", "bbb"))) assert block_hashes[2] == sha256( (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), - ("bbb", ))) + None, ("bbb", ))) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, @@ -933,7 +933,7 @@ def test_mm_prefix_caching(): assert len(block_hashes) == 4 assert block_hashes[3] == sha256( (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), - ("ccc", ))) + None, ("ccc", ))) # Cache hit. 
unique_token_ids = [-1] * 7 + [200] * 5 @@ -977,12 +977,14 @@ def test_cache_key_salting(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None, + ("salt1", ))) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None, + None)) assert block_hashes[2] == sha256( (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + None, None)) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, @@ -1000,7 +1002,8 @@ def test_cache_key_salting(): assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) + (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None, + None)) # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 @@ -1019,12 +1022,14 @@ def test_cache_key_salting(): block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None, + ("salt2", ))) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None, + None)) assert block_hashes[2] == sha256( (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + None, None)) def test_prefill_not_enough_free_blocks_with_computed_blocks(): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bf293a4d2aa9..c724a4aedfdf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1552,16 +1552,6 @@ def _set_default_args(self, usage_context: UsageContext, if model_config.runner_type != "pooling": self.enable_chunked_prefill = True - # TODO: When prefix caching supports prompt embeds inputs, this - # check can be removed. - if (self.enable_prompt_embeds - and self.enable_prefix_caching is not False): - logger.warning( - "--enable-prompt-embeds and --enable-prefix-caching " - "are not supported together in V1. Prefix caching has " - "been disabled.") - self.enable_prefix_caching = False - if self.enable_prefix_caching is None: self.enable_prefix_caching = True else: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2ff1bb681d80..bacec96d3914 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -9,6 +9,8 @@ from dataclasses import dataclass from typing import Any, Callable, NewType, Optional, Union +import torch + from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger @@ -20,6 +22,7 @@ UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request +from vllm.v1.utils import tensor_data # BlockHash represents the hash of a single KV-cache block used for # prefix caching. 
Treating it as a distinct type from ``bytes`` helps @@ -545,10 +548,12 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable[[Any], bytes], - parent_block_hash: Optional[BlockHash], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, + curr_block_prompt_embeds: Optional[torch.Tensor] = None, +) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -560,6 +565,7 @@ def hash_block_tokens( curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. extra_keys: Extra keys for the block. + curr_block_prompt_embeds: The prompt embeddings of the current block. Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. @@ -568,9 +574,12 @@ def hash_block_tokens( parent_block_hash = NONE_HASH curr_block_token_ids_tuple = tuple(curr_block_token_ids) + curr_block_prompt_embeds_bytes = ( + None if curr_block_prompt_embeds is None else + tensor_data(curr_block_prompt_embeds).tobytes()) return BlockHash( - hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) + hash_function((parent_block_hash, curr_block_token_ids_tuple, + curr_block_prompt_embeds_bytes, extra_keys))) def get_request_block_hasher( @@ -608,9 +617,12 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + block_prompt_embeds = ( + None if request.prompt_embeds is None else + request.prompt_embeds[start_token_idx:end_token_idx]) block_hash = hash_block_tokens(caching_hash_fn, prev_block_hash_value, block_tokens, - extra_keys) + extra_keys, block_prompt_embeds) new_block_hashes.append(block_hash) start_token_idx += block_size diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 876838084b9a..ebd3363f73c9 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -28,6 +28,7 @@ MultiModalSharedField, NestedTensors) # yapf: enable from vllm.v1.engine import UtilityResult +from vllm.v1.utils import tensor_data logger = init_logger(__name__) @@ -208,14 +209,14 @@ def _encode_tensor( ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().view(torch.uint8).numpy() + arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. - data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data) else: # Otherwise encode index of backing buffer to avoid copy. 
data = len(self.aux_buffers) - self.aux_buffers.append(arr.data) + self.aux_buffers.append(arr_data) dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index ee0c1168f3cd..269608eb06b4 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -394,3 +394,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: _PROFILER_FUNC = func return func(name) + + +def tensor_data(tensor: torch.Tensor) -> memoryview: + return tensor.flatten().contiguous().view(torch.uint8).numpy().data
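
The new `tensor_data` helper in `vllm/v1/utils.py` centralizes the flatten/contiguous/uint8-view conversion that `_encode_tensor` in `serial_utils.py` previously did inline, and the prompt-embeds hashing path reuses it. A minimal sketch of its behaviour, assuming a CPU tensor that does not require grad; the assertions are illustrative only:

    import torch

    def tensor_data(tensor: torch.Tensor) -> memoryview:
        # Reinterpret the tensor's storage as a flat run of bytes without
        # copying (for already-contiguous tensors): flatten to 1-D, force a
        # contiguous layout, reinterpret as uint8, and expose the backing
        # NumPy buffer as a memoryview.
        return tensor.flatten().contiguous().view(torch.uint8).numpy().data

    embeds = torch.randn(3, 5)  # e.g. one block's worth of prompt embeddings
    raw = tensor_data(embeds)
    assert isinstance(raw, memoryview)
    assert len(raw.tobytes()) == embeds.numel() * embeds.element_size()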
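
The core of the change in `kv_cache_utils.py` is that every block hash key grows from `(parent_hash, token_ids, extra_keys)` to `(parent_hash, token_ids, prompt_embeds_bytes, extra_keys)`, with `None` in the third slot for token-only requests, which is why every expected tuple in the updated tests gains an extra element. A self-contained sketch of the scheme; the pickle-then-sha256 hash here is only a stand-in for the configured `caching_hash_fn`:

    import hashlib
    import pickle
    from typing import Any, Optional

    import torch

    def tensor_bytes(t: torch.Tensor) -> bytes:
        return t.flatten().contiguous().view(torch.uint8).numpy().tobytes()

    def hash_block(parent_hash: bytes,
                   block_token_ids: tuple[int, ...],
                   block_prompt_embeds: Optional[torch.Tensor],
                   extra_keys: Optional[tuple[Any, ...]]) -> bytes:
        # Embeddings are folded in as raw bytes; token-only requests keep a
        # None placeholder so the tuple layout stays fixed.
        embeds_bytes = (None if block_prompt_embeds is None else
                        tensor_bytes(block_prompt_embeds))
        key = (parent_hash, block_token_ids, embeds_bytes, extra_keys)
        return hashlib.sha256(pickle.dumps(key)).digest()

    root = b"none-hash"
    tokens = (0, 1, 2)
    embeds_a = torch.randn(3, 4)
    embeds_b = torch.randn(3, 4)

    # Same token ids, different embeddings -> different block hashes, so the
    # prefix cache cannot serve KV blocks computed for other embeddings.
    assert (hash_block(root, tokens, embeds_a, None)
            != hash_block(root, tokens, embeds_b, None))
    # Without embeddings the hash stays deterministic for identical inputs.
    assert (hash_block(root, tokens, None, None)
            == hash_block(root, tokens, None, None))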
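
With the guard in `arg_utils.py` removed, `--enable-prompt-embeds` no longer forces prefix caching off in V1; `enable_prefix_caching` simply keeps its usual default of `True`. A hedged usage sketch; the model name is a placeholder and the exact prompt-embeds request format is covered in `prompt_embeds.md`:

    from vllm import LLM

    # Both flags can now be set together; previously V1 logged a warning and
    # silently disabled prefix caching whenever prompt embeds were enabled.
    llm = LLM(
        model="facebook/opt-125m",   # placeholder model
        enable_prompt_embeds=True,
        enable_prefix_caching=True,  # may also be left at its default
    )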