4 changes: 2 additions & 2 deletions docs/features/README.md
@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)<sup>^</sup> | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](https://github.com/vllm-project/vllm/issues/25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |

\* Chunked prefill and prefix caching are only applicable to last-token pooling.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
@@ -75,4 +75,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
137 changes: 136 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from collections.abc import Callable
from typing import Any

import pytest
import torch
@@ -32,6 +33,7 @@
init_none_hash,
is_kv_cache_spec_uniform,
make_block_hash_with_group_id,
tensor_data,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):

def make_request(
request_id: str,
prompt_token_ids: list[int],
prompt_token_ids: list[int] | None,
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: list[PlaceholderRange] | None = None,
mm_hashes: list[str] | None = None,
cache_salt: str | None = None,
prompt_embeds: torch.Tensor | None = None,
):
mm_features = []
if mm_positions is not None:
@@ -90,6 +93,7 @@ def make_request(
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn),
prompt_embeds=prompt_embeds,
)


@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1


def test_generate_block_hash_extra_keys_prompt_embeds():
prompt_embeds = torch.randn(10, 3)
request = make_request(
request_id="0",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds,
)

# Test with prompt embeds for the first block
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
expected_embeds = prompt_embeds[0:5]
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
assert extra_keys == (expected_bytes,)

# Test with prompt embeds for the second block
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
expected_embeds = prompt_embeds[5:10]
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
assert extra_keys == (expected_bytes,)


def test_generate_block_hash_extra_keys_different_prompt_embeds():
prompt_embeds1 = torch.randn(10, 3)
prompt_embeds2 = torch.randn(10, 3)
request1 = make_request(
request_id="0",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds1,
)
request2 = make_request(
request_id="1",
prompt_token_ids=None,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds2,
)

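    # Identical token ranges with different prompt embeds must produce
    # different extra keys, so the blocks are not shared via prefix caching.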
extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
assert extra_keys1 != extra_keys2


def test_generate_block_hash_extra_keys_lora():
request = make_request(
request_id="0",
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
block_size = 3
num_tokens = 2 * block_size
    prompt_token_ids = list(range(num_tokens))
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))

request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
prompt_embeds=prompt_embeds,
)

block_hashes = request.block_hashes
assert len(block_hashes) == 2

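    # Each block hash chains the parent hash (NONE_HASH for the first block),
    # the block's token ids, and the serialized prompt embeds as extra keys.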
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(
kv_cache_utils.NONE_HASH,
tuple(prompt_token_ids[:block_size]),
(block1_embeds_bytes,),
)
)
assert block_hashes[0] == expected_hash1

block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(
block_hashes[0],
tuple(prompt_token_ids[block_size:num_tokens]),
(block2_embeds_bytes,),
)
)
assert block_hashes[1] == expected_hash2


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
block_size = 3
num_tokens = 2 * block_size
    prompt_token_ids = list(range(num_tokens))
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))

request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
],
mm_hashes=["hash1", "hash2"],
prompt_embeds=prompt_embeds,
)

block_hashes = request.block_hashes
assert len(block_hashes) == 2

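    # With both mm inputs and prompt embeds, each block's extra keys are its
    # mm hash followed by the bytes of its slice of the prompt embeds.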
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(
kv_cache_utils.NONE_HASH,
tuple(prompt_token_ids[:block_size]),
("hash1", block1_embeds_bytes),
)
)
assert block_hashes[0] == expected_hash1

block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(
block_hashes[0],
tuple(prompt_token_ids[block_size:num_tokens]),
("hash2", block2_embeds_bytes),
)
)
assert block_hashes[1] == expected_hash2
10 changes: 0 additions & 10 deletions vllm/engine/arg_utils.py
@@ -1743,16 +1743,6 @@ def _set_default_args(
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True

# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled."
)
self.enable_prefix_caching = False

if self.enable_prefix_caching is None:
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
32 changes: 30 additions & 2 deletions vllm/v1/core/kv_cache_utils.py
@@ -26,6 +26,7 @@
UniformTypeKVCacheSpecs,
)
from vllm.v1.request import Request
from vllm.v1.utils import tensor_data

# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from `bytes` helps
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
return [request.lora_request.lora_name]


def _gen_prompt_embeds_extra_hash_keys(
request: Request, start_token_idx: int, end_token_idx: int
) -> list[bytes]:
"""Generate extra keys related to prompt embeds for block hash computation.

Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.

Returns:
        A list containing the serialized bytes of the prompt embeddings for
        the given token range if the request has prompt embeds, or an empty
        list otherwise.
"""
if request.prompt_embeds is None:
return []
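    # Hash only the slice of embeddings that backs this block's tokens, so
    # blocks with identical token ids but different embeddings hash
    # differently.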
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
return [embeds_bytes]


def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
) -> tuple[tuple[Any, ...] | None, int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA name).
    the multi-modal inputs, request-specific metadata (e.g., the LoRA name),
    and data from prompt embeddings.

Args:
request: The request object.
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
cache_salt_keys: list[str] = (
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
)
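    # Per-block prompt embeddings also contribute to the hash, so prefix
    # caching can distinguish requests whose token ids match but whose
    # embeddings differ.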
prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys(
request, start_token_idx, end_token_idx
)

extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
extra_keys: list[Any] = (
lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
)

if not extra_keys:
return None, new_start_mm_idx
7 changes: 4 additions & 3 deletions vllm/v1/serial_utils.py
@@ -31,6 +31,7 @@
NestedTensors,
)
from vllm.v1.engine import UtilityResult
from vllm.v1.utils import tensor_data

logger = init_logger(__name__)

@@ -218,14 +219,14 @@ def _encode_tensor(
) -> tuple[str, tuple[int, ...], int | memoryview]:
assert self.aux_buffers is not None
# view the tensor as a contiguous 1D array of bytes
arr = obj.flatten().contiguous().view(torch.uint8).numpy()
arr_data = tensor_data(obj)
if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr.data)
self.aux_buffers.append(arr_data)
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data

13 changes: 13 additions & 0 deletions vllm/v1/utils.py
@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:

_PROFILER_FUNC = func
return func(name)


def tensor_data(tensor: torch.Tensor) -> memoryview:
"""Get the raw data of a tensor as a uint8 memoryview, useful for
serializing and hashing.

Args:
tensor: The input tensor.

Returns:
A memoryview of the tensor data as uint8.
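
    Example:
        >>> t = torch.randn(2, 3)  # 6 float32 elements -> 24 bytes
        >>> tensor_data(t).nbytes
        24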
"""
return tensor.flatten().contiguous().view(torch.uint8).numpy().data