diff --git a/docs/features/README.md b/docs/features/README.md
index 05ce0b57a9fc..537b5862e39f 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | |
-| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
+| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
^ LoRA is only applicable to the language backbone of multimodal models.
@@ -77,4 +77,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](gh-issue:25097) |
+| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:25097) |
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 09f43a793db2..f7544da6a3a7 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
import pytest
import torch
@@ -29,6 +29,7 @@
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
+from vllm.v1.utils import tensor_data
# yapf: enable
@@ -47,12 +48,13 @@ def _auto_init_hash_fn(request):
def make_request(
request_id: str,
- prompt_token_ids: list[int],
+ prompt_token_ids: Optional[list[int]],
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
):
mm_features = []
if mm_positions is not None:
@@ -72,6 +74,7 @@ def make_request(
pooling_params=None,
eos_token_id=100,
lora_request=None,
+ prompt_embeds=prompt_embeds,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn))
@@ -441,7 +444,8 @@ def test_hash_block_tokens(hash_fn):
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
- expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
+ expected = hash_fn(
+ (parent_block_hash, curr_block_token_ids, None, extra_keys))
assert block_hash == expected
@@ -462,9 +466,9 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0] == hash_fn(
- (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
+ (kv_cache_utils.NONE_HASH, (0, 1, 2), None, ("hash1", )))
assert block_hashes[1] == hash_fn(
- (block_hashes[0], (3, 4, 5), ("hash2", )))
+ (block_hashes[0], (3, 4, 5), None, ("hash2", )))
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
@@ -510,8 +514,8 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
assert len(block_hashes) == 2
assert block_hashes[0] == hash_fn(
- (kv_cache_utils.NONE_HASH, (0, 1, 2), None))
- assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
+ (kv_cache_utils.NONE_HASH, (0, 1, 2), None, None))
+ assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None, None))
def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats:
@@ -1452,3 +1456,153 @@ def test_merge_mla_spec():
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_hash_block_tokens_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
+ parent_block_hash = BlockHash(b"123")
+ curr_block_token_ids = (1, 2, 3)
+ extra_keys = ("key1", "key2")
+ prompt_embeds = torch.randn((2, 3))
+
+ block_hash_with_embeds = hash_block_tokens(hash_fn, parent_block_hash,
+ curr_block_token_ids,
+ extra_keys, prompt_embeds)
+
+ prompt_embeds_bytes = tensor_data(prompt_embeds).tobytes()
+ expected = hash_fn((parent_block_hash, curr_block_token_ids,
+ prompt_embeds_bytes, extra_keys))
+ assert block_hash_with_embeds == expected
+
+ block_hash_without_embeds = hash_block_tokens(hash_fn, parent_block_hash,
+ curr_block_token_ids,
+ extra_keys, None)
+ expected_without = hash_fn(
+ (parent_block_hash, curr_block_token_ids, None, extra_keys))
+ assert block_hash_without_embeds == expected_without
+ assert block_hash_with_embeds != block_hash_without_embeds
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_hash_different_prompt_embeds(hash_fn: Callable[[Any], bytes]):
+ parent_block_hash = BlockHash(b"123")
+ curr_block_token_ids = (1, 2, 3)
+ prompt_embeds1 = torch.randn((2, 3))
+ prompt_embeds2 = torch.randn((2, 3))
+
+ hash1 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids,
+ None, prompt_embeds1)
+ hash2 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids,
+ None, prompt_embeds2)
+
+ assert hash1 != hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any],
+ bytes]):
+ block_size = 3
+ num_tokens = 2 * block_size
+    prompt_token_ids = list(range(num_tokens))
+ hidden_size = 5
+ prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=prompt_token_ids,
+ block_size=block_size,
+ hash_fn=hash_fn,
+ prompt_embeds=prompt_embeds,
+ )
+
+ block_hashes = request.block_hashes
+ assert len(block_hashes) == 2
+
+ block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+ expected_hash1 = hash_fn(
+ (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
+ block1_embeds_bytes, None))
+ assert block_hashes[0] == expected_hash1
+
+ block2_embeds_bytes = tensor_data(
+ prompt_embeds[block_size:num_tokens]).tobytes()
+ expected_hash2 = hash_fn(
+ (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
+ block2_embeds_bytes, None))
+ assert block_hashes[1] == expected_hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any],
+ bytes]):
+ block_size = 3
+ num_tokens = 2 * block_size
+    prompt_token_ids = list(range(num_tokens))
+ hidden_size = 5
+ prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=prompt_token_ids,
+ block_size=block_size,
+ hash_fn=hash_fn,
+ mm_positions=[
+ PlaceholderRange(offset=0, length=3),
+ PlaceholderRange(offset=3, length=3),
+ ],
+ mm_hashes=["hash1", "hash2"],
+ prompt_embeds=prompt_embeds,
+ )
+
+ block_hashes = request.block_hashes
+ assert len(block_hashes) == 2
+
+ block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+ expected_hash1 = hash_fn(
+ (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
+ block1_embeds_bytes, ("hash1", )))
+ assert block_hashes[0] == expected_hash1
+
+ block2_embeds_bytes = tensor_data(
+ prompt_embeds[block_size:num_tokens]).tobytes()
+ expected_hash2 = hash_fn(
+ (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
+ block2_embeds_bytes, ("hash2", )))
+ assert block_hashes[1] == expected_hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_with_prompt_embeds_no_mm(hash_fn: Callable[[Any], bytes]):
+ """Test request with prompt embeddings but no multimodal inputs"""
+ block_size = 3
+ num_tokens = 2 * block_size
+    prompt_token_ids = list(range(num_tokens))
+ hidden_size = 5
+ prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+ request = make_request(
+ request_id="0",
+ prompt_token_ids=prompt_token_ids,
+ block_size=block_size,
+ hash_fn=hash_fn,
+ mm_positions=None,
+ mm_hashes=None,
+ prompt_embeds=prompt_embeds,
+ )
+
+ block_hashes = request.block_hashes
+ assert len(block_hashes) == 2
+
+ # Verify hashes include prompt embeddings but no mm keys
+ block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+ expected_hash1 = hash_fn(
+ (kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
+ block1_embeds_bytes, None))
+ assert block_hashes[0] == expected_hash1
+
+ block2_embeds_bytes = tensor_data(
+ prompt_embeds[block_size:num_tokens]).tobytes()
+ expected_hash2 = hash_fn(
+ (block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
+ block2_embeds_bytes, None))
+ assert block_hashes[1] == expected_hash2
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index 93ad4d8080e6..9a27d73c875c 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -907,14 +907,14 @@ def test_mm_prefix_caching():
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
- (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
+ (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), None,
("aaa", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
- ("aaa", "bbb")))
+ None, ("aaa", "bbb")))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
- ("bbb", )))
+ None, ("bbb", )))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
@@ -933,7 +933,7 @@ def test_mm_prefix_caching():
assert len(block_hashes) == 4
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
- ("ccc", )))
+ None, ("ccc", )))
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
@@ -977,12 +977,14 @@ def test_cache_key_salting():
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
- (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
+ (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None,
+ ("salt1", )))
assert block_hashes[1] == sha256(
- (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
+ (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None,
+ None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
- None))
+ None, None))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
@@ -1000,7 +1002,8 @@ def test_cache_key_salting():
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
assert len(block_hashes) == 4
assert block_hashes[3] == sha256(
- (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
+ (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None,
+ None))
# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
@@ -1019,12 +1022,14 @@ def test_cache_key_salting():
block_hashes = req2.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
- (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
+ (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None,
+ ("salt2", )))
assert block_hashes[1] == sha256(
- (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
+ (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None,
+ None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
- None))
+ None, None))
def test_prefill_not_enough_free_blocks_with_computed_blocks():
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index bf293a4d2aa9..c724a4aedfdf 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1552,16 +1552,6 @@ def _set_default_args(self, usage_context: UsageContext,
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
- # TODO: When prefix caching supports prompt embeds inputs, this
- # check can be removed.
- if (self.enable_prompt_embeds
- and self.enable_prefix_caching is not False):
- logger.warning(
- "--enable-prompt-embeds and --enable-prefix-caching "
- "are not supported together in V1. Prefix caching has "
- "been disabled.")
- self.enable_prefix_caching = False
-
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 2ff1bb681d80..bacec96d3914 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -9,6 +9,8 @@
from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union
+import torch
+
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -20,6 +22,7 @@
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
+from vllm.v1.utils import tensor_data
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from ``bytes`` helps
@@ -545,10 +548,12 @@ def generate_block_hash_extra_keys(
def hash_block_tokens(
- hash_function: Callable[[Any], bytes],
- parent_block_hash: Optional[BlockHash],
- curr_block_token_ids: Sequence[int],
- extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
+ hash_function: Callable[[Any], bytes],
+ parent_block_hash: Optional[BlockHash],
+ curr_block_token_ids: Sequence[int],
+ extra_keys: Optional[tuple[Any, ...]] = None,
+ curr_block_prompt_embeds: Optional[torch.Tensor] = None,
+) -> BlockHash:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
@@ -560,6 +565,7 @@ def hash_block_tokens(
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
+ curr_block_prompt_embeds: The prompt embeddings of the current block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
@@ -568,9 +574,12 @@ def hash_block_tokens(
parent_block_hash = NONE_HASH
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
+ curr_block_prompt_embeds_bytes = (
+ None if curr_block_prompt_embeds is None else
+ tensor_data(curr_block_prompt_embeds).tobytes())
return BlockHash(
- hash_function(
- (parent_block_hash, curr_block_token_ids_tuple, extra_keys)))
+ hash_function((parent_block_hash, curr_block_token_ids_tuple,
+ curr_block_prompt_embeds_bytes, extra_keys)))
def get_request_block_hasher(
@@ -608,9 +617,12 @@ def request_block_hasher(request: Request) -> list[BlockHash]:
# Compute the hash of the current block
block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
+ block_prompt_embeds = (
+ None if request.prompt_embeds is None else
+ request.prompt_embeds[start_token_idx:end_token_idx])
block_hash = hash_block_tokens(caching_hash_fn,
prev_block_hash_value, block_tokens,
- extra_keys)
+ extra_keys, block_prompt_embeds)
new_block_hashes.append(block_hash)
start_token_idx += block_size
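
The core change: each block hash key gains a fourth component, the raw bytes of that block's slice of prompt embeddings, so the key becomes `(parent_hash, block_token_ids, prompt_embeds_bytes_or_None, extra_keys)`. A self-contained sketch of the idea, using `hashlib` and `pickle` as stand-ins for vLLM's configured hash function:

```python
import hashlib
import pickle

import torch


def tensor_bytes(t: torch.Tensor) -> bytes:
    # Same raw-bytes view as the new vllm.v1.utils.tensor_data helper,
    # materialized to bytes for hashing.
    return t.flatten().contiguous().view(torch.uint8).numpy().tobytes()


parent_hash = b"\x00" * 32    # stand-in for NONE_HASH on the first block
block_tokens = (0, 1, 2)      # identical token ids in both requests
embeds_a = torch.randn(3, 8)  # this block's slice of prompt embeddings
embeds_b = torch.randn(3, 8)  # a different request's embeddings


def block_key(embeds: torch.Tensor) -> tuple:
    return (parent_hash, block_tokens, tensor_bytes(embeds), None)


hash_a = hashlib.sha256(pickle.dumps(block_key(embeds_a))).digest()
hash_b = hashlib.sha256(pickle.dumps(block_key(embeds_b))).digest()

# Same token ids, different embeddings -> different block hashes, so the
# prefix cache cannot hand one request the other's KV blocks.
assert hash_a != hash_b
```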
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 876838084b9a..ebd3363f73c9 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -28,6 +28,7 @@
MultiModalSharedField, NestedTensors)
# yapf: enable
from vllm.v1.engine import UtilityResult
+from vllm.v1.utils import tensor_data
logger = init_logger(__name__)
@@ -208,14 +209,14 @@ def _encode_tensor(
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# view the tensor as a contiguous 1D array of bytes
- arr = obj.flatten().contiguous().view(torch.uint8).numpy()
+ arr_data = tensor_data(obj)
if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays.
- data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+ data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
- self.aux_buffers.append(arr.data)
+ self.aux_buffers.append(arr_data)
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index ee0c1168f3cd..269608eb06b4 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -394,3 +394,8 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
_PROFILER_FUNC = func
return func(name)
+
+
+def tensor_data(tensor: torch.Tensor) -> memoryview:
+    """Return a memoryview over the raw bytes of ``tensor``."""
+    return tensor.flatten().contiguous().view(torch.uint8).numpy().data
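
`tensor_data` is now shared by the serializer and the block hasher: it exposes a tensor's backing bytes as a `memoryview` without caring about dtype or shape. A quick behavioral sketch (CPU tensors only, since the helper round-trips through NumPy):

```python
import torch
from vllm.v1.utils import tensor_data

t = torch.arange(6, dtype=torch.float32).reshape(2, 3).t()  # non-contiguous
view = tensor_data(t)

assert isinstance(view, memoryview)
assert len(view) == t.numel() * t.element_size()
# flatten().contiguous() puts elements in logical (row-major) order, so the
# bytes match those of an explicitly contiguous copy.
assert bytes(view) == tensor_data(t.contiguous()).tobytes()
```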