diff --git a/docs/features/README.md b/docs/features/README.md
index 7faec0dc84f3..ad9de9ff8f36 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
-| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](https://github.com/vllm-project/vllm/issues/25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
+| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
^ LoRA is only applicable to the language backbone of multimodal models.
@@ -75,4 +75,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
-| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
+| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index d192c58a8c15..df6a5f109874 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
+from typing import Any
import pytest
import torch
@@ -32,6 +33,7 @@
init_none_hash,
is_kv_cache_spec_uniform,
make_block_hash_with_group_id,
+ tensor_data,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
def make_request(
request_id: str,
- prompt_token_ids: list[int],
+ prompt_token_ids: list[int] | None,
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: list[PlaceholderRange] | None = None,
mm_hashes: list[str] | None = None,
cache_salt: str | None = None,
+ prompt_embeds: torch.Tensor | None = None,
):
mm_features = []
if mm_positions is not None:
@@ -90,6 +93,7 @@ def make_request(
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn),
+ prompt_embeds=prompt_embeds,
)
@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1
+def test_generate_block_hash_extra_keys_prompt_embeds():
+ prompt_embeds = torch.randn(10, 3)
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=None,
+ mm_positions=None,
+ mm_hashes=None,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # Test with prompt embeds for the first block
+ extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+ expected_embeds = prompt_embeds[0:5]
+ expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+ assert extra_keys == (expected_bytes,)
+
+ # Test with prompt embeds for the second block
+ extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
+ expected_embeds = prompt_embeds[5:10]
+ expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+ assert extra_keys == (expected_bytes,)
+
+
+def test_generate_block_hash_extra_keys_different_prompt_embeds():
+ prompt_embeds1 = torch.randn(10, 3)
+ prompt_embeds2 = torch.randn(10, 3)
+ request1 = make_request(
+ request_id="0",
+ prompt_token_ids=None,
+ mm_positions=None,
+ mm_hashes=None,
+ prompt_embeds=prompt_embeds1,
+ )
+ request2 = make_request(
+ request_id="1",
+ prompt_token_ids=None,
+ mm_positions=None,
+ mm_hashes=None,
+ prompt_embeds=prompt_embeds2,
+ )
+
+ extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
+ extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
+ assert extra_keys1 != extra_keys2
+
+
def test_generate_block_hash_extra_keys_lora():
request = make_request(
request_id="0",
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
+ block_size = 3
+ num_tokens = 2 * block_size
+ prompt_token_ids = [_ for _ in range(num_tokens)]
+ hidden_size = 5
+ prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=prompt_token_ids,
+ block_size=block_size,
+ hash_fn=hash_fn,
+ prompt_embeds=prompt_embeds,
+ )
+
+ block_hashes = request.block_hashes
+ assert len(block_hashes) == 2
+
+ block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+ expected_hash1 = hash_fn(
+ (
+ kv_cache_utils.NONE_HASH,
+ tuple(prompt_token_ids[:block_size]),
+ (block1_embeds_bytes,),
+ )
+ )
+ assert block_hashes[0] == expected_hash1
+
+ block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+ expected_hash2 = hash_fn(
+ (
+ block_hashes[0],
+ tuple(prompt_token_ids[block_size:num_tokens]),
+ (block2_embeds_bytes,),
+ )
+ )
+ assert block_hashes[1] == expected_hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
+ block_size = 3
+ num_tokens = 2 * block_size
+ prompt_token_ids = [_ for _ in range(num_tokens)]
+ hidden_size = 5
+ prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=prompt_token_ids,
+ block_size=block_size,
+ hash_fn=hash_fn,
+ mm_positions=[
+ PlaceholderRange(offset=0, length=3),
+ PlaceholderRange(offset=3, length=3),
+ ],
+ mm_hashes=["hash1", "hash2"],
+ prompt_embeds=prompt_embeds,
+ )
+
+ block_hashes = request.block_hashes
+ assert len(block_hashes) == 2
+
+ block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+ expected_hash1 = hash_fn(
+ (
+ kv_cache_utils.NONE_HASH,
+ tuple(prompt_token_ids[:block_size]),
+ ("hash1", block1_embeds_bytes),
+ )
+ )
+ assert block_hashes[0] == expected_hash1
+
+ block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+ expected_hash2 = hash_fn(
+ (
+ block_hashes[0],
+ tuple(prompt_token_ids[block_size:num_tokens]),
+ ("hash2", block2_embeds_bytes),
+ )
+ )
+ assert block_hashes[1] == expected_hash2
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ce41c377e457..3fdc0befecef 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1743,16 +1743,6 @@ def _set_default_args(
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
- # TODO: When prefix caching supports prompt embeds inputs, this
- # check can be removed.
- if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
- logger.warning(
- "--enable-prompt-embeds and --enable-prefix-caching "
- "are not supported together in V1. Prefix caching has "
- "been disabled."
- )
- self.enable_prefix_caching = False
-
if self.enable_prefix_caching is None:
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 3b2ce8f54229..334c6393a8c9 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -26,6 +26,7 @@
UniformTypeKVCacheSpecs,
)
from vllm.v1.request import Request
+from vllm.v1.utils import tensor_data
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from `bytes` helps
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
return [request.lora_request.lora_name]
+def _gen_prompt_embeds_extra_hash_keys(
+ request: Request, start_token_idx: int, end_token_idx: int
+) -> list[bytes]:
+ """Generate extra keys related to prompt embeds for block hash computation.
+
+ Args:
+ request: The request object.
+ start_token_idx: The start token index of the block.
+ end_token_idx: The end token index of the block.
+
+ Returns:
+ Return prompt embeddings data of the request if it has prompt embeds.
+ Return empty list otherwise.
+ """
+ if request.prompt_embeds is None:
+ return []
+ block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
+ embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
+ return [embeds_bytes]
+
+
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
) -> tuple[tuple[Any, ...] | None, int]:
"""Generate extra keys for the block hash. The extra keys can come from
- the multi-modal inputs and request specific metadata (e.g., LoRA name).
+ the multi-modal inputs, request specific metadata (e.g., LoRA names), and
+ data from prompt embeddings.
Args:
request: The request object.
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
cache_salt_keys: list[str] = (
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
)
+ prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys(
+ request, start_token_idx, end_token_idx
+ )
- extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
+ extra_keys: list[Any] = (
+ lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
+ )
if not extra_keys:
return None, new_start_mm_idx
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 528c9671dbfd..39147a67d6cf 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -31,6 +31,7 @@
NestedTensors,
)
from vllm.v1.engine import UtilityResult
+from vllm.v1.utils import tensor_data
logger = init_logger(__name__)
@@ -218,14 +219,14 @@ def _encode_tensor(
) -> tuple[str, tuple[int, ...], int | memoryview]:
assert self.aux_buffers is not None
# view the tensor as a contiguous 1D array of bytes
- arr = obj.flatten().contiguous().view(torch.uint8).numpy()
+ arr_data = tensor_data(obj)
if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays.
- data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+ data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
- self.aux_buffers.append(arr.data)
+ self.aux_buffers.append(arr_data)
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index e8fa81266469..789a74cc6c4a 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
_PROFILER_FUNC = func
return func(name)
+
+
+def tensor_data(tensor: torch.Tensor) -> memoryview:
+ """Get the raw data of a tensor as a uint8 memoryview, useful for
+ serializing and hashing.
+
+ Args:
+ tensor: The input tensor.
+
+ Returns:
+ A memoryview of the tensor data as uint8.
+ """
+ return tensor.flatten().contiguous().view(torch.uint8).numpy().data