diff --git a/docs/features/README.md b/docs/features/README.md index 7faec0dc84f3..ad9de9ff8f36 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -52,7 +52,7 @@ th:not(:first-child) { | [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | | best-of | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | | | beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | -| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](https://github.com/vllm-project/vllm/issues/25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | \* Chunked prefill and prefix caching are only applicable to last-token pooling. ^ LoRA is only applicable to the language backbone of multimodal models. @@ -75,4 +75,4 @@ th:not(:first-child) { | multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ | | best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | | beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | -| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ | diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index d192c58a8c15..df6a5f109874 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib from collections.abc import Callable +from typing import Any import pytest import torch @@ -32,6 +33,7 @@ init_none_hash, is_kv_cache_spec_uniform, make_block_hash_with_group_id, + tensor_data, ) from vllm.v1.kv_cache_interface import ( FullAttentionSpec, @@ -61,12 +63,13 @@ def _auto_init_hash_fn(request): def make_request( request_id: str, - prompt_token_ids: list[int], + prompt_token_ids: list[int] | None, block_size: int = 3, hash_fn: Callable = hash, mm_positions: list[PlaceholderRange] | None = None, mm_hashes: list[str] | None = None, cache_salt: str | None = None, + prompt_embeds: torch.Tensor | None = None, ): mm_features = [] if mm_positions is not None: @@ -90,6 +93,7 @@ def make_request( lora_request=None, cache_salt=cache_salt, block_hasher=get_request_block_hasher(block_size, hash_fn), + prompt_embeds=prompt_embeds, ) @@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 +def test_generate_block_hash_extra_keys_prompt_embeds(): + prompt_embeds = torch.randn(10, 3) + request = make_request( + request_id="0", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds, + ) + + # Test with prompt embeds for the first block + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0) + expected_embeds = prompt_embeds[0:5] + expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes() + assert extra_keys == (expected_bytes,) + + # Test with prompt embeds for the second block + extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0) + expected_embeds = prompt_embeds[5:10] + expected_bytes = 
kv_cache_utils.tensor_data(expected_embeds).tobytes() + assert extra_keys == (expected_bytes,) + + +def test_generate_block_hash_extra_keys_different_prompt_embeds(): + prompt_embeds1 = torch.randn(10, 3) + prompt_embeds2 = torch.randn(10, 3) + request1 = make_request( + request_id="0", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds1, + ) + request2 = make_request( + request_id="1", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds2, + ) + + extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0) + extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0) + assert extra_keys1 != extra_keys2 + + def test_generate_block_hash_extra_keys_lora(): request = make_request( request_id="0", @@ -1556,3 +1606,88 @@ def test_merge_mla_spec(): ] with pytest.raises(AssertionError): kv_cache_specs[0].merge(kv_cache_specs) + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + ( + kv_cache_utils.NONE_HASH, + tuple(prompt_token_ids[:block_size]), + (block1_embeds_bytes,), + ) + ) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + ( + block_hashes[0], + tuple(prompt_token_ids[block_size:num_tokens]), + (block2_embeds_bytes,), + ) + ) + assert block_hashes[1] == expected_hash2 + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], + mm_hashes=["hash1", "hash2"], + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + ( + kv_cache_utils.NONE_HASH, + tuple(prompt_token_ids[:block_size]), + ("hash1", block1_embeds_bytes), + ) + ) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + ( + block_hashes[0], + tuple(prompt_token_ids[block_size:num_tokens]), + ("hash2", block2_embeds_bytes), + ) + ) + assert block_hashes[1] == expected_hash2 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ce41c377e457..3fdc0befecef 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1743,16 +1743,6 @@ def _set_default_args( if model_config.runner_type != "pooling": self.enable_chunked_prefill = True - # TODO: When prefix caching supports 
prompt embeds inputs, this - # check can be removed. - if self.enable_prompt_embeds and self.enable_prefix_caching is not False: - logger.warning( - "--enable-prompt-embeds and --enable-prefix-caching " - "are not supported together in V1. Prefix caching has " - "been disabled." - ) - self.enable_prefix_caching = False - if self.enable_prefix_caching is None: # Disable prefix caching default for hybrid models # since the feature is still experimental. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3b2ce8f54229..334c6393a8c9 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -26,6 +26,7 @@ UniformTypeKVCacheSpecs, ) from vllm.v1.request import Request +from vllm.v1.utils import tensor_data # BlockHash represents the hash of a single KV-cache block used for # prefix caching. Treating it as a distinct type from `bytes` helps @@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]: return [request.lora_request.lora_name] +def _gen_prompt_embeds_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int +) -> list[bytes]: + """Generate extra keys related to prompt embeds for block hash computation. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + + Returns: + Return prompt embeddings data of the request if it has prompt embeds. + Return empty list otherwise. + """ + if request.prompt_embeds is None: + return [] + block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx] + embeds_bytes = tensor_data(block_prompt_embeds).tobytes() + return [embeds_bytes] + + def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int ) -> tuple[tuple[Any, ...] | None, int]: """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA name). + the multi-modal inputs, request specific metadata (e.g., LoRA names), and + data from prompt embeddings. Args: request: The request object. @@ -484,8 +507,13 @@ def generate_block_hash_extra_keys( cache_salt_keys: list[str] = ( [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] ) + prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys( + request, start_token_idx, end_token_idx + ) - extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys + extra_keys: list[Any] = ( + lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys + ) if not extra_keys: return None, new_start_mm_idx diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 528c9671dbfd..39147a67d6cf 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -31,6 +31,7 @@ NestedTensors, ) from vllm.v1.engine import UtilityResult +from vllm.v1.utils import tensor_data logger = init_logger(__name__) @@ -218,14 +219,14 @@ def _encode_tensor( ) -> tuple[str, tuple[int, ...], int | memoryview]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().view(torch.uint8).numpy() + arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. - data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data) else: # Otherwise encode index of backing buffer to avoid copy. 
data = len(self.aux_buffers) - self.aux_buffers.append(arr.data) + self.aux_buffers.append(arr_data) dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index e8fa81266469..789a74cc6c4a 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: _PROFILER_FUNC = func return func(name) + + +def tensor_data(tensor: torch.Tensor) -> memoryview: + """Get the raw data of a tensor as a uint8 memoryview, useful for + serializing and hashing. + + Args: + tensor: The input tensor. + + Returns: + A memoryview of the tensor data as uint8. + """ + return tensor.flatten().contiguous().view(torch.uint8).numpy().data
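
For reference, a minimal self-contained sketch of how the new per-block prompt-embeds extra keys feed the chained block hashes, mirroring `test_request_block_hasher_with_prompt_embeds` above. It is not part of the patch: `tensor_data` is inlined with the same expression as the new helper in `vllm/v1/utils.py`, `hash_fn` is a pickle-plus-sha256 stand-in for vLLM's hash functions, and the `NONE_HASH` seed is replaced by a placeholder byte string.

```python
import hashlib
import pickle

import torch


def tensor_data(tensor: torch.Tensor) -> memoryview:
    # Same conversion as the new vllm.v1.utils.tensor_data helper:
    # a raw uint8 view of the tensor's data.
    return tensor.flatten().contiguous().view(torch.uint8).numpy().data


def hash_fn(obj) -> bytes:
    # Stand-in for vLLM's sha256 hash function; pickling the key tuple is an
    # assumption made only for this sketch.
    return hashlib.sha256(pickle.dumps(obj)).digest()


block_size = 3
prompt_token_ids = list(range(2 * block_size))
prompt_embeds = torch.randn(2 * block_size, 5)

parent_hash = b"<NONE_HASH>"  # placeholder for kv_cache_utils.NONE_HASH
block_hashes: list[bytes] = []
for start in range(0, len(prompt_token_ids), block_size):
    end = start + block_size
    # Extra key contributed by _gen_prompt_embeds_extra_hash_keys: the raw
    # bytes of this block's slice of the prompt embeddings.
    embeds_bytes = tensor_data(prompt_embeds[start:end]).tobytes()
    block_hash = hash_fn(
        (parent_hash, tuple(prompt_token_ids[start:end]), (embeds_bytes,))
    )
    block_hashes.append(block_hash)
    parent_hash = block_hash

# Identical token ids with different embeddings hash to different blocks.
assert block_hashes[0] != block_hashes[1]
print([h.hex()[:16] for h in block_hashes])
```

Because each block's key tuple now includes the raw bytes of its embedding slice, two requests with identical token ids but different prompt embeddings no longer collide in the prefix cache (see `test_generate_block_hash_extra_keys_different_prompt_embeds`), which is why the `--enable-prompt-embeds` / `--enable-prefix-caching` incompatibility check in `vllm/engine/arg_utils.py` can be dropped.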