4 changes: 2 additions & 2 deletions docs/features/README.md
@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)<sup>^</sup> | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
| [prompt-embeds](prompt_embeds.md) | ✅ | | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |

\* Chunked prefill and prefix caching are only applicable to last-token pooling.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
@@ -77,4 +77,4 @@ th:not(:first-child) {
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ |
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](gh-issue:25097) |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [❌](gh-issue:25097) |
168 changes: 161 additions & 7 deletions tests/v1/core/test_kv_cache_utils.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import Callable, Optional
from typing import Any, Callable, Optional

import pytest
import torch
@@ -29,6 +29,7 @@
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
from vllm.v1.utils import tensor_data

# yapf: enable

@@ -47,12 +48,13 @@ def _auto_init_hash_fn(request):

def make_request(
request_id: str,
prompt_token_ids: list[int],
prompt_token_ids: Optional[list[int]],
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
):
mm_features = []
if mm_positions is not None:
Expand All @@ -72,6 +74,7 @@ def make_request(
pooling_params=None,
eos_token_id=100,
lora_request=None,
prompt_embeds=prompt_embeds,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn))

@@ -441,7 +444,8 @@ def test_hash_block_tokens(hash_fn):

block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
expected = hash_fn(
(parent_block_hash, curr_block_token_ids, None, extra_keys))
assert block_hash == expected


@@ -462,9 +466,9 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
(kv_cache_utils.NONE_HASH, (0, 1, 2), None, ("hash1", )))
assert block_hashes[1] == hash_fn(
(block_hashes[0], (3, 4, 5), ("hash2", )))
(block_hashes[0], (3, 4, 5), None, ("hash2", )))


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
@@ -510,8 +514,8 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):

assert len(block_hashes) == 2
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
(kv_cache_utils.NONE_HASH, (0, 1, 2), None, None))
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None, None))


def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats:
@@ -1452,3 +1456,153 @@ def test_merge_mla_spec():
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
prompt_embeds = torch.randn((2, 3))

block_hash_with_embeds = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids,
extra_keys, prompt_embeds)

prompt_embeds_bytes = tensor_data(prompt_embeds).tobytes()
expected = hash_fn((parent_block_hash, curr_block_token_ids,
prompt_embeds_bytes, extra_keys))
assert block_hash_with_embeds == expected

block_hash_without_embeds = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids,
extra_keys, None)
expected_without = hash_fn(
(parent_block_hash, curr_block_token_ids, None, extra_keys))
assert block_hash_without_embeds == expected_without
assert block_hash_with_embeds != block_hash_without_embeds


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_different_prompt_embeds(hash_fn: Callable[[Any], bytes]):
parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3)
prompt_embeds1 = torch.randn((2, 3))
prompt_embeds2 = torch.randn((2, 3))

hash1 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids,
None, prompt_embeds1)
hash2 = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids,
None, prompt_embeds2)

assert hash1 != hash2


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any],
bytes]):
block_size = 3
num_tokens = 2 * block_size
prompt_token_ids = [_ for _ in range(num_tokens)]
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))

request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
prompt_embeds=prompt_embeds,
)

block_hashes = request.block_hashes
assert len(block_hashes) == 2

block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
block1_embeds_bytes, None))
assert block_hashes[0] == expected_hash1

block2_embeds_bytes = tensor_data(
prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
block2_embeds_bytes, None))
assert block_hashes[1] == expected_hash2


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any],
bytes]):
block_size = 3
num_tokens = 2 * block_size
prompt_token_ids = [_ for _ in range(num_tokens)]
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))

request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
],
mm_hashes=["hash1", "hash2"],
prompt_embeds=prompt_embeds,
)

block_hashes = request.block_hashes
assert len(block_hashes) == 2

block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
block1_embeds_bytes, ("hash1", )))
assert block_hashes[0] == expected_hash1

block2_embeds_bytes = tensor_data(
prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
block2_embeds_bytes, ("hash2", )))
assert block_hashes[1] == expected_hash2


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_with_prompt_embeds_no_mm(hash_fn: Callable[[Any], bytes]):
"""Test request with prompt embeddings but no multimodal inputs"""
block_size = 3
num_tokens = 2 * block_size
prompt_token_ids = [_ for _ in range(num_tokens)]
hidden_size = 5
prompt_embeds = torch.randn((num_tokens, hidden_size))

request = make_request(
request_id="0",
prompt_token_ids=prompt_token_ids,
block_size=block_size,
hash_fn=hash_fn,
mm_positions=None,
mm_hashes=None,
prompt_embeds=prompt_embeds,
)

block_hashes = request.block_hashes
assert len(block_hashes) == 2

# Verify hashes include prompt embeddings but no mm keys
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
expected_hash1 = hash_fn(
(kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]),
block1_embeds_bytes, None))
assert block_hashes[0] == expected_hash1

block2_embeds_bytes = tensor_data(
prompt_embeds[block_size:num_tokens]).tobytes()
expected_hash2 = hash_fn(
(block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]),
block2_embeds_bytes, None))
assert block_hashes[1] == expected_hash2
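
The new tests above pin the hash inputs down to a four-element tuple: `(parent_block_hash, block_token_ids, prompt_embeds_bytes_or_None, extra_keys)`. As a rough standalone sketch of that chained scheme — not vLLM's actual `hash_block_tokens`/`tensor_data` code, and using `hashlib` + `pickle` in place of the project's `sha256`/`sha256_cbor` helpers — per-block hashes for a request carrying prompt embeds could be computed like this:

```python
# Standalone sketch (not the PR's implementation) of the chained block-hash
# scheme exercised by the tests above: each block hash folds in the parent
# hash, the block's token IDs, the bytes of that block's prompt-embedding
# rows (None for token-only requests), and any extra keys such as
# multimodal hashes or a cache salt.
import hashlib
import pickle
from typing import Any, Optional

import torch


def hash_block(parent_hash: bytes,
               block_token_ids: tuple[int, ...],
               embeds_bytes: Optional[bytes],
               extra_keys: Optional[tuple[Any, ...]]) -> bytes:
    """Fold one block's inputs into a hash chained off the parent block."""
    payload = (parent_hash, block_token_ids, embeds_bytes, extra_keys)
    return hashlib.sha256(pickle.dumps(payload)).digest()


block_size = 3
token_ids = list(range(2 * block_size))
prompt_embeds = torch.randn(2 * block_size, 5)

parent = b"none-hash"  # stand-in for kv_cache_utils.NONE_HASH
block_hashes = []
for start in range(0, len(token_ids), block_size):
    block_tokens = tuple(token_ids[start:start + block_size])
    # Bytes of the embedding rows covered by this block.
    block_embeds = prompt_embeds[start:start + block_size]
    embeds_bytes = block_embeds.contiguous().numpy().tobytes()
    parent = hash_block(parent, block_tokens, embeds_bytes, None)
    block_hashes.append(parent)

# Identical token IDs with different prompt embeds now yield different
# block hashes, so such requests cannot hit each other's cached KV blocks.
```

Folding the embedding bytes into a slot that stays `None` for token-only requests keeps the existing token-only hashes well-defined while making embeds-backed blocks distinguishable from each other and from plain-token blocks.
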
27 changes: 16 additions & 11 deletions tests/v1/core/test_prefix_caching.py
@@ -907,14 +907,14 @@ def test_mm_prefix_caching():
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), None,
("aaa", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
("aaa", "bbb")))
None, ("aaa", "bbb")))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
("bbb", )))
None, ("bbb", )))

blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
@@ -933,7 +933,7 @@ def test_mm_prefix_caching():
assert len(block_hashes) == 4
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
("ccc", )))
None, ("ccc", )))

# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
@@ -977,12 +977,14 @@ def test_cache_key_salting():
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None,
("salt1", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None,
None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
None, None))

blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
@@ -1000,7 +1002,8 @@ def test_cache_key_salting():
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
assert len(block_hashes) == 4
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None,
None))

# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
@@ -1019,12 +1022,14 @@ def test_cache_key_salting():
block_hashes = req2.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), None,
("salt2", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None,
None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
None, None))


def test_prefill_not_enough_free_blocks_with_computed_blocks():
10 changes: 0 additions & 10 deletions vllm/engine/arg_utils.py
@@ -1552,16 +1552,6 @@ def _set_default_args(self, usage_context: UsageContext,
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True

# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if (self.enable_prompt_embeds
and self.enable_prefix_caching is not False):
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False

if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
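
The deleted block above was the V1 guard that force-disabled prefix caching whenever `--enable-prompt-embeds` was set. With it gone, the two options are assumed to be usable together; a minimal, hypothetical usage sketch (not part of this PR — field names are taken from the deleted lines, and the model name is only a placeholder):

```python
# Hypothetical usage sketch: with the guard above deleted, prompt embeds and
# prefix caching are assumed to coexist in V1. Field names come from the
# deleted lines; the model name is a placeholder.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",       # placeholder model
    enable_prompt_embeds=True,       # accept prompt_embeds inputs
    enable_prefix_caching=True,      # no longer forced to False
)
print(args.enable_prefix_caching)    # stays True instead of being reset
```
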