20 changes: 19 additions & 1 deletion docs/source/design/v1/prefix_caching.md
@@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified

* Parent hash value: The hash value of the parent hash block.
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments.

> **Note 1:** We only cache full blocks.
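
For illustration, the sketch below shows one way the three inputs above could be folded into a single block hash. It is a simplified, hypothetical helper (`hash_block` and its key layout are assumptions), not the actual vLLM implementation:

```python
# Illustrative sketch only; the real vLLM code uses its own hash key structure.
import hashlib
import pickle
from typing import Any, Optional


def hash_block(parent_hash: Optional[int],
               block_tokens: tuple,
               extra_keys: Optional[tuple] = None) -> int:
    """Combine parent hash, block tokens, and extra keys (LoRA ID,
    multi-modal input hashes, cache salt, ...) into one block hash."""
    payload = pickle.dumps((parent_hash, block_tokens, extra_keys))
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "big")


# First block: no parent, salt included as an extra key.
h0 = hash_block(None, (1, 2, 3), extra_keys=("salt1",))
# Second block chains on the first block's hash.
h1 = hash_block(h0, (4, 5, 6))
```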
@@ -76,6 +76,24 @@ Block 3

In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.

**Cache Isolation for Security**
To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. When a `cache_salt` is included in a request, it is injected into the hash of the first block, so only requests carrying the same salt can reuse cached KV blocks. This prevents timing-based attacks in which an adversary could infer cached content by observing latency differences, and the protection comes without compromising performance.

```json
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Here is a document with details about the world series: ..."},
{"role": "user", "content": "Who won the world series in 2020?"}
],
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
}
```

With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
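
As a usage sketch, the same salt can also be passed through the OpenAI-compatible Python client via `extra_body`; the server URL and model name below are placeholders, not part of this feature:

```python
# Sketch: sending cache_salt to a vLLM OpenAI-compatible server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
    ],
    # Only requests sharing this salt can reuse each other's cached prefix blocks.
    extra_body={
        "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
    },
)
print(response.choices[0].message.content)
```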

> **Note:** Cache isolation is not supported in engine V0.

## Data Structure

The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):
40 changes: 40 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():

assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05


def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig()

mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

# Test cache_salt
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
)

# By default cache_salt in the engine prompt is not set
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert "cache_salt" not in mock_engine.generate.call_args.args[0]

# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
12 changes: 10 additions & 2 deletions tests/tokenization/test_detokenize.py
@@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
request = EngineCoreRequest("", prompt_token_ids, None, None, None, params,
None, 0.0, None)
request = EngineCoreRequest("",
prompt_token_ids,
None,
None,
None,
params,
None,
0.0,
None,
cache_salt=None)

if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
43 changes: 42 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
@@ -29,7 +29,8 @@
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
mm_hashes=None,
cache_salt=None):
if mm_positions is None:
multi_modal_inputs = None
else:
@@ -45,6 +46,7 @@ def make_request(request_id,
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)


@@ -213,6 +215,45 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
assert next_mm_idx == 0


def test_generate_block_hash_extra_keys_cache_salt():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
cache_salt="salt",
)

# salt is added for the first token
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0)
assert extra_keys == ('salt', )
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0)
assert extra_keys == ('salt', )

# no salt added for other tokens
extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0)
assert extra_keys is None
extra_keys, _ = generate_block_hash_extra_keys(request, 6, 10, 0)
assert extra_keys is None

# works together with other extra keys
request_mm = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(20)],
mm_positions=[
PlaceholderRange(offset=0, length=5),
],
mm_hashes=["hash1"],
cache_salt="salt",
)

# The first block should combine the multi-modal hash and the salt.
extra_keys, next_mm_idx = generate_block_hash_extra_keys(
request_mm, 0, 5, 0)
assert extra_keys == ("hash1", "salt")
assert next_mm_idx == 1


@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_block_tokens(hash_fn):
parent_block_hash = 123
64 changes: 63 additions & 1 deletion tests/v1/core/test_prefix_caching.py
@@ -21,7 +21,8 @@ def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None,
prompt_logprobs: Optional[int] = None):
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None):
if mm_positions is None:
multi_modal_inputs = None
else:
@@ -38,6 +39,7 @@
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)


@@ -603,6 +605,66 @@ def test_mm_prefix_caching():
assert num_computed_tokens == 3 * 16


def test_cache_key_salting():
"""
This tests that cache salts are applied during hashing and the cache
is separated as expected.
"""
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)

# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

# Completed blocks should have hashes; only the first block carries extra keys.
assert not computed_blocks
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", )
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0

# Now one more block that should not have extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys is None

# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks) == 3
assert num_computed_tokens == 3 * block_size

# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 0
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id]
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", )


def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
This is a unit test that tests the correctness of the allocate_slots
1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core.py
@@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest:
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
)


1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core_client.py
@@ -43,6 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
)


7 changes: 6 additions & 1 deletion tests/v1/engine/test_output_processor.py
@@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool,
reason should be "stop" (i.e. first control token causes stop
and is represented in output text)

* else, the detokenized string should be
* else, the detokenized string should be
<token><token>...<token> and the finish reason should be "stop"
(i.e. first control token causes stop but is not represented
in output text.)
@@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
mm_placeholders=None,
eos_token_id=eos_token_id,
lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool,
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors):
mm_placeholders=None,
eos_token_id=None,
lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(),
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
32 changes: 28 additions & 4 deletions vllm/entrypoints/openai/protocol.py
@@ -14,6 +14,7 @@
ValidationInfo, field_validator, model_validator)
from typing_extensions import TypeAlias

from vllm import envs
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
@@ -408,6 +409,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))

# doc: end-chat-completion-extra-params

@@ -726,6 +736,20 @@ def check_generation_prompt(cls, data):
"`add_generation_prompt` to True.")
return data

@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data


class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@@ -1622,9 +1646,9 @@ class TranscriptionRequest(OpenAIBaseModel):

# doc: begin-transcription-extra-params
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False
@@ -1642,15 +1666,15 @@ class TranscriptionRequest(OpenAIBaseModel):
"""

top_p: Optional[float] = None
"""Enables nucleus (top-p) sampling, where tokens are selected from the
"""Enables nucleus (top-p) sampling, where tokens are selected from the
smallest possible set whose cumulative probability exceeds `p`.
"""

top_k: Optional[int] = None
"""Limits sampling to the `k` most probable tokens at each step."""

min_p: Optional[float] = None
"""Filters out tokens with a probability lower than `min_p`, ensuring a
"""Filters out tokens with a probability lower than `min_p`, ensuring a
minimum likelihood threshold during sampling.
"""

3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
@@ -470,6 +470,9 @@ async def _preprocess_chat(
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt

return conversation, [request_prompt], [engine_prompt]

def _log_inputs(