
Commit ff93cc8

[CORE] Support Prefix Caching with Prompt Embeds (#27219)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
1 parent 243ed7d commit ff93cc8

File tree

6 files changed: +185 -18 lines changed
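With this change, prompt-embeds requests participate in prefix caching instead of forcing it off. A minimal offline sketch of the combination (hedged: the model name, tensor shapes, and the dict-style `prompt_embeds` input follow the `prompt_embeds.md` feature docs rather than this diff; they are assumptions, not part of the commit):

```python
# Sketch only: assumes the offline LLM API and the EmbedsPrompt input format
# described in docs/features/prompt_embeds.md; names below are illustrative.
import torch
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # example model, not from this commit
    enable_prompt_embeds=True,
    enable_prefix_caching=True,                # no longer forced off (see arg_utils.py below)
)

hidden_size = 2048                             # must match the model's hidden size
shared_prefix = torch.randn(256, hidden_size)  # identical embeddings -> shareable KV blocks
suffix_a = torch.randn(16, hidden_size)
suffix_b = torch.randn(16, hidden_size)

outputs = llm.generate(
    [
        {"prompt_embeds": torch.cat([shared_prefix, suffix_a])},
        {"prompt_embeds": torch.cat([shared_prefix, suffix_b])},
    ]
)
```

Because block hashes now fold in the raw embedding bytes (see vllm/v1/core/kv_cache_utils.py below), the two requests can only share cached blocks when their prefix embeddings are byte-for-byte identical.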

docs/features/README.md

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ th:not(:first-child) {
 | [mm](multimodal_inputs.md) ||| [🟠](https://github.com/vllm-project/vllm/pull/4194)<sup>^</sup> |||||||||| | | |
 | best-of |||| [](https://github.com/vllm-project/vllm/issues/6137) ||||||| [](https://github.com/vllm-project/vllm/issues/7968) ||| | |
 | beam-search |||| [](https://github.com/vllm-project/vllm/issues/6137) ||||||| [](https://github.com/vllm-project/vllm/issues/7968) |||| |
-| [prompt-embeds](prompt_embeds.md) || [](https://github.com/vllm-project/vllm/issues/25096) ||||||||||||||
+| [prompt-embeds](prompt_embeds.md) || ||||||||||||||

 \* Chunked prefill and prefix caching are only applicable to last-token pooling.
 <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
@@ -75,4 +75,4 @@ th:not(:first-child) {
 | multi-step |||||| [](https://github.com/vllm-project/vllm/issues/8477) ||||
 | best-of ||||||||||
 | beam-search ||||||||||
-| [prompt-embeds](prompt_embeds.md) ||||||| ? | [](https://github.com/vllm-project/vllm/issues/25097) ||
+| [prompt-embeds](prompt_embeds.md) ||||||| | [](https://github.com/vllm-project/vllm/issues/25097) ||

tests/v1/core/test_kv_cache_utils.py

Lines changed: 136 additions & 1 deletion
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
 from collections.abc import Callable
+from typing import Any

 import pytest
 import torch
@@ -32,6 +33,7 @@
     init_none_hash,
     is_kv_cache_spec_uniform,
     make_block_hash_with_group_id,
+    tensor_data,
 )
 from vllm.v1.kv_cache_interface import (
     FullAttentionSpec,
@@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):

 def make_request(
     request_id: str,
-    prompt_token_ids: list[int],
+    prompt_token_ids: list[int] | None,
     block_size: int = 3,
     hash_fn: Callable = hash,
     mm_positions: list[PlaceholderRange] | None = None,
     mm_hashes: list[str] | None = None,
     cache_salt: str | None = None,
+    prompt_embeds: torch.Tensor | None = None,
 ):
     mm_features = []
     if mm_positions is not None:
@@ -90,6 +93,7 @@ def make_request(
         lora_request=None,
         cache_salt=cache_salt,
         block_hasher=get_request_block_hasher(block_size, hash_fn),
+        prompt_embeds=prompt_embeds,
     )


@@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
     assert next_mm_idx == 1


+def test_generate_block_hash_extra_keys_prompt_embeds():
+    prompt_embeds = torch.randn(10, 3)
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds,
+    )
+
+    # Test with prompt embeds for the first block
+    extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+    expected_embeds = prompt_embeds[0:5]
+    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+    assert extra_keys == (expected_bytes,)
+
+    # Test with prompt embeds for the second block
+    extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
+    expected_embeds = prompt_embeds[5:10]
+    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
+    assert extra_keys == (expected_bytes,)
+
+
+def test_generate_block_hash_extra_keys_different_prompt_embeds():
+    prompt_embeds1 = torch.randn(10, 3)
+    prompt_embeds2 = torch.randn(10, 3)
+    request1 = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds1,
+    )
+    request2 = make_request(
+        request_id="1",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds2,
+    )
+
+    extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
+    extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
+    assert extra_keys1 != extra_keys2
+
+
 def test_generate_block_hash_extra_keys_lora():
     request = make_request(
         request_id="0",
@@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
     ]
     with pytest.raises(AssertionError):
         kv_cache_specs[0].merge(kv_cache_specs)
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
+    block_size = 3
+    num_tokens = 2 * block_size
+    prompt_token_ids = [_ for _ in range(num_tokens)]
+    hidden_size = 5
+    prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=prompt_token_ids,
+        block_size=block_size,
+        hash_fn=hash_fn,
+        prompt_embeds=prompt_embeds,
+    )
+
+    block_hashes = request.block_hashes
+    assert len(block_hashes) == 2
+
+    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    expected_hash1 = hash_fn(
+        (
+            kv_cache_utils.NONE_HASH,
+            tuple(prompt_token_ids[:block_size]),
+            (block1_embeds_bytes,),
+        )
+    )
+    assert block_hashes[0] == expected_hash1
+
+    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    expected_hash2 = hash_fn(
+        (
+            block_hashes[0],
+            tuple(prompt_token_ids[block_size:num_tokens]),
+            (block2_embeds_bytes,),
+        )
+    )
+    assert block_hashes[1] == expected_hash2
+
+
+@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
+def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
+    block_size = 3
+    num_tokens = 2 * block_size
+    prompt_token_ids = [_ for _ in range(num_tokens)]
+    hidden_size = 5
+    prompt_embeds = torch.randn((num_tokens, hidden_size))
+
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=prompt_token_ids,
+        block_size=block_size,
+        hash_fn=hash_fn,
+        mm_positions=[
+            PlaceholderRange(offset=0, length=3),
+            PlaceholderRange(offset=3, length=3),
+        ],
+        mm_hashes=["hash1", "hash2"],
+        prompt_embeds=prompt_embeds,
+    )
+
+    block_hashes = request.block_hashes
+    assert len(block_hashes) == 2
+
+    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    expected_hash1 = hash_fn(
+        (
+            kv_cache_utils.NONE_HASH,
+            tuple(prompt_token_ids[:block_size]),
+            ("hash1", block1_embeds_bytes),
+        )
+    )
+    assert block_hashes[0] == expected_hash1
+
+    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    expected_hash2 = hash_fn(
+        (
+            block_hashes[0],
+            tuple(prompt_token_ids[block_size:num_tokens]),
+            ("hash2", block2_embeds_bytes),
+        )
+    )
+    assert block_hashes[1] == expected_hash2
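The two parametrized tests above pin down the hash structure: each block hash covers the parent block's hash, the block's token IDs, and an extra-keys tuple that now carries the block's embedding bytes (plus any multimodal hashes). A simplified, self-contained illustration of that chaining (assumptions: plain `hashlib.sha256` over pickled tuples stands in for vLLM's `sha256`/`sha256_cbor` hash functions, and only full blocks are hashed for brevity):

```python
import hashlib
import pickle

import torch

NONE_HASH = b"\x00" * 32  # stand-in for kv_cache_utils.NONE_HASH


def tensor_bytes(t: torch.Tensor) -> bytes:
    # Same byte view the new tensor_data() helper exposes.
    return t.flatten().contiguous().view(torch.uint8).numpy().tobytes()


def block_hashes(token_ids: list[int], embeds: torch.Tensor, block_size: int) -> list[bytes]:
    hashes, parent = [], NONE_HASH
    for start in range(0, len(token_ids), block_size):
        end = start + block_size
        extra_keys = (tensor_bytes(embeds[start:end]),)  # embedding bytes now feed the hash
        key = (parent, tuple(token_ids[start:end]), extra_keys)
        parent = hashlib.sha256(pickle.dumps(key)).digest()
        hashes.append(parent)
    return hashes


tokens = list(range(6))
h1 = block_hashes(tokens, torch.randn(6, 5), block_size=3)
h2 = block_hashes(tokens, torch.randn(6, 5), block_size=3)
assert h1 != h2  # same token IDs, different embeddings -> no false cache hits
```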

vllm/engine/arg_utils.py

Lines changed: 0 additions & 10 deletions
@@ -1743,16 +1743,6 @@ def _set_default_args(
         if model_config.runner_type != "pooling":
             self.enable_chunked_prefill = True

-        # TODO: When prefix caching supports prompt embeds inputs, this
-        # check can be removed.
-        if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
-            logger.warning(
-                "--enable-prompt-embeds and --enable-prefix-caching "
-                "are not supported together in V1. Prefix caching has "
-                "been disabled."
-            )
-            self.enable_prefix_caching = False
-
         if self.enable_prefix_caching is None:
             # Disable prefix caching default for hybrid models
             # since the feature is still experimental.
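With the forced-disable branch removed, both options can be requested together and prefix caching keeps whatever value the user (or the defaults) chose. A small sketch of that expectation (hedged: the `EngineArgs` field names come from the removed code above; the model name is only an example):

```python
from vllm.engine.arg_utils import EngineArgs

# Both features requested together; before this commit, engine-config creation
# would log a warning and force enable_prefix_caching to False.
args = EngineArgs(
    model="facebook/opt-125m",  # illustrative model
    enable_prompt_embeds=True,
    enable_prefix_caching=True,
)
print(args.enable_prompt_embeds, args.enable_prefix_caching)
```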

vllm/v1/core/kv_cache_utils.py

Lines changed: 30 additions & 2 deletions
@@ -26,6 +26,7 @@
     UniformTypeKVCacheSpecs,
 )
 from vllm.v1.request import Request
+from vllm.v1.utils import tensor_data

 # BlockHash represents the hash of a single KV-cache block used for
 # prefix caching. Treating it as a distinct type from `bytes` helps
@@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
     return [request.lora_request.lora_name]


+def _gen_prompt_embeds_extra_hash_keys(
+    request: Request, start_token_idx: int, end_token_idx: int
+) -> list[bytes]:
+    """Generate extra keys related to prompt embeds for block hash computation.
+
+    Args:
+        request: The request object.
+        start_token_idx: The start token index of the block.
+        end_token_idx: The end token index of the block.
+
+    Returns:
+        Return prompt embeddings data of the request if it has prompt embeds.
+        Return empty list otherwise.
+    """
+    if request.prompt_embeds is None:
+        return []
+    block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
+    embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
+    return [embeds_bytes]
+
+
 def generate_block_hash_extra_keys(
     request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
 ) -> tuple[tuple[Any, ...] | None, int]:
     """Generate extra keys for the block hash. The extra keys can come from
-    the multi-modal inputs and request specific metadata (e.g., LoRA name).
+    the multi-modal inputs, request specific metadata (e.g., LoRA names), and
+    data from prompt embeddings.

     Args:
         request: The request object.
@@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
     cache_salt_keys: list[str] = (
         [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
     )
+    prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys(
+        request, start_token_idx, end_token_idx
+    )

-    extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
+    extra_keys: list[Any] = (
+        lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
+    )

     if not extra_keys:
         return None, new_start_mm_idx
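The design point here is the composition order and the fallback: prompt-embeds bytes are appended after the LoRA, multimodal, and cache-salt keys, and requests with no extra keys still hash exactly as before. A compact sketch of that tail logic (hypothetical helper name; the real code is `generate_block_hash_extra_keys` above):

```python
from typing import Any

import torch


def combine_extra_keys(lora_keys: list[str], mm_keys: list[str],
                       cache_salt_keys: list[str],
                       prompt_embeds_keys: list[bytes]) -> tuple[Any, ...] | None:
    # Mirrors the tail of generate_block_hash_extra_keys: embeds bytes are appended
    # last, and None is still returned when no extra keys apply, so token-only
    # requests keep the same block hashes they had before this commit.
    extra_keys: list[Any] = lora_keys + mm_keys + cache_salt_keys + prompt_embeds_keys
    return tuple(extra_keys) if extra_keys else None


embeds_bytes = torch.randn(16, 8).flatten().view(torch.uint8).numpy().tobytes()
print(combine_extra_keys([], [], [], [embeds_bytes]))  # (b'...',) -> embeds alter the hash
print(combine_extra_keys([], [], [], []))              # None -> hash unchanged
```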

vllm/v1/serial_utils.py

Lines changed: 4 additions & 3 deletions

@@ -31,6 +31,7 @@
     NestedTensors,
 )
 from vllm.v1.engine import UtilityResult
+from vllm.v1.utils import tensor_data

 logger = init_logger(__name__)

@@ -218,14 +219,14 @@ def _encode_tensor(
     ) -> tuple[str, tuple[int, ...], int | memoryview]:
         assert self.aux_buffers is not None
         # view the tensor as a contiguous 1D array of bytes
-        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
+        arr_data = tensor_data(obj)
         if obj.nbytes < self.size_threshold:
             # Smaller tensors are encoded inline, just like ndarrays.
-            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
         else:
             # Otherwise encode index of backing buffer to avoid copy.
             data = len(self.aux_buffers)
-            self.aux_buffers.append(arr.data)
+            self.aux_buffers.append(arr_data)
         dtype = str(obj.dtype).removeprefix("torch.")
         return dtype, obj.shape, data
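The serializer change is a pure refactor: `tensor_data(obj)` yields the same uint8 memoryview the old inline expression produced, so the msgpack encoding is unchanged while the byte-view logic now lives in one place shared with the hashing path. A quick equivalence check (a sketch; any small CPU tensor works):

```python
import torch
from vllm.v1.utils import tensor_data

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)
old_view = t.flatten().contiguous().view(torch.uint8).numpy().data  # previous inline expression
assert bytes(old_view) == bytes(tensor_data(t))  # byte-for-byte identical encoding input
```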

vllm/v1/utils.py

Lines changed: 13 additions & 0 deletions

@@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:

         _PROFILER_FUNC = func
     return func(name)
+
+
+def tensor_data(tensor: torch.Tensor) -> memoryview:
+    """Get the raw data of a tensor as a uint8 memoryview, useful for
+    serializing and hashing.
+
+    Args:
+        tensor: The input tensor.
+
+    Returns:
+        A memoryview of the tensor data as uint8.
+    """
+    return tensor.flatten().contiguous().view(torch.uint8).numpy().data
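Usage note for the new helper: for a contiguous CPU tensor the returned memoryview aliases the tensor's storage (no copy); `.contiguous()` only copies when the input is non-contiguous, and callers such as the block hasher call `.tobytes()` when they need an owned bytes object. A short sketch (variable names are illustrative):

```python
import torch
from vllm.v1.utils import tensor_data

t = torch.randn(4, 8, dtype=torch.float32)
view = tensor_data(t)                              # uint8 memoryview over the tensor's bytes
assert len(view) == t.numel() * t.element_size()
block_key_material = tensor_data(t[:2]).tobytes()  # copy out bytes for hashing, as kv_cache_utils does
```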
