diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md
index ec1f3cb8d64a..ec661d8ec641 100644
--- a/docs/source/design/v1/prefix_caching.md
+++ b/docs/source/design/v1/prefix_caching.md
@@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified
 * Parent hash value: The hash value of the parent hash block.
 * Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
-* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
+* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments.
 
 > **Note 1:** We only cache full blocks.
 
@@ -76,6 +76,24 @@ Block 3
 In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.
 
+**Cache Isolation for Security**
+To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. When a request includes a `cache_salt`, the salt is injected into the hash of the first block, so only requests that provide the same salt can reuse cached KV blocks. This prevents timing-based attacks in which an adversary infers cached content by observing latency differences, and it provides this protection without compromising performance.
+
+```json
+{
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Here is a document with details about the world series: ..."},
+        {"role": "user", "content": "Who won the world series in 2020?"}
+    ],
+    "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
+}
+```
+
+With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
+
+> **Note:** Cache isolation is not supported in engine V0.
+
 ## Data Structure
 
 The prefix caching in vLLM v1 is implemented in the KV cache manager.
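As a companion to the JSON payload in the docs above, here is a minimal sketch of setting `cache_salt` from the `openai` Python client via its `extra_body` parameter. The base URL, API key, and model name are placeholders, and a locally running vLLM OpenAI-compatible server on the V1 engine is assumed:

```python
# Illustrative only: assumes a vLLM OpenAI-compatible server (V1 engine)
# is running locally; base_url, api_key, and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
    ],
    # extra_body forwards fields outside the official OpenAI schema;
    # vLLM reads `cache_salt` from the request body.
    extra_body={"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="},
)
print(response.choices[0].message.content)
```

Requests that omit the salt, or use a different one, hash their first block differently and therefore cannot reuse these cached blocks.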
The basic building block is the “Block” data class (simplified): diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 19d16713b209..5e11af8cf892 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + +def test_serving_chat_did_set_correct_cache_salt(): + mock_model_config = MockModelConfig() + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = OpenAIServingChat(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + # Test cache_salt + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + ) + + # By default cache_salt in the engine prompt is not set + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert "cache_salt" not in mock_engine.generate.call_args.args[0] + + # Test with certain cache_salt + req.cache_salt = "test_salt" + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f8e213b9ca48..079100e78b5f 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer, skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest("", prompt_token_ids, None, None, None, params, - None, 0.0, None) + request = EngineCoreRequest("", + prompt_token_ids, + None, + None, + None, + params, + None, + 0.0, + None, + cache_salt=None) if fast is None: detokenizer = IncrementalDetokenizer.from_new_request( diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e73e08e74b0d..e8069b8c6d7f 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -29,7 +29,8 @@ def make_request(request_id, prompt_token_ids, mm_positions=None, - mm_hashes=None): + mm_hashes=None, + cache_salt=None): if mm_positions is None: multi_modal_inputs = None else: @@ -45,6 +46,7 @@ def make_request(request_id, eos_token_id=100, arrival_time=0, lora_request=None, + cache_salt=cache_salt, ) @@ -213,6 +215,45 @@ def test_generate_block_hash_extra_keys_no_mm_inputs(): assert next_mm_idx == 0 +def test_generate_block_hash_extra_keys_cache_salt(): + request = make_request( + request_id=0, + prompt_token_ids=[_ for _ in range(6)], + mm_positions=None, + mm_hashes=None, + cache_salt="salt", + ) + + # salt is added for the first token + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0) + assert extra_keys == ('salt', ) + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0) + assert extra_keys == ('salt', ) + + # no salt added for other tokens + extra_keys, _ = 
generate_block_hash_extra_keys(request, 1, 2, 0)
+    assert extra_keys is None
+    extra_keys, _ = generate_block_hash_extra_keys(request, 6, 10, 0)
+    assert extra_keys is None
+
+    # Works together with other extra keys.
+    request_mm = make_request(
+        request_id=0,
+        prompt_token_ids=[_ for _ in range(20)],
+        mm_positions=[
+            PlaceholderRange(offset=0, length=5),
+        ],
+        mm_hashes=["hash1"],
+        cache_salt="salt",
+    )
+
+    # Both the mm hash and the cache salt are included for the first block.
+    extra_keys, next_mm_idx = generate_block_hash_extra_keys(
+        request_mm, 0, 5, 0)
+    assert extra_keys == ("hash1", "salt")
+    assert next_mm_idx == 1
+
+
 @pytest.mark.parametrize("hash_fn", [sha256, hash])
 def test_hash_block_tokens(hash_fn):
     parent_block_hash = 123
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index b2e8ff61450c..ae4bd95d22aa 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -21,7 +21,8 @@ def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
                  mm_hashes=None,
-                 prompt_logprobs: Optional[int] = None):
+                 prompt_logprobs: Optional[int] = None,
+                 cache_salt: Optional[str] = None):
     if mm_positions is None:
         multi_modal_inputs = None
     else:
@@ -38,6 +39,7 @@ def make_request(request_id,
         eos_token_id=100,
         arrival_time=0,
         lora_request=None,
+        cache_salt=cache_salt,
     )
 
 
@@ -603,6 +605,66 @@ def test_mm_prefix_caching():
     assert num_computed_tokens == 3 * 16
 
 
+def test_cache_key_salting():
+    """
+    This tests that cache salts are applied during hashing and the cache
+    is separated as expected.
+    """
+    block_size = 16
+    manager = KVCacheManager(
+        make_kv_cache_config(block_size, 11),
+        max_model_len=8192,
+        enable_caching=True,
+    )
+
+    # 3 complete blocks and an incomplete block with 11 tokens.
+    common_token_ids = [i for i in range(3) for _ in range(block_size)]
+    token_ids = common_token_ids + [3] * 11
+    req0 = make_request("0", token_ids, cache_salt="salt1")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+
+    # Completed blocks should have hashes, with the salt only on the first.
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    block_hashes = manager.req_to_block_hashes[req0.request_id]
+    assert len(block_hashes) == 3
+    assert block_hashes[0].extra_keys == ("salt1", )
+    assert block_hashes[1].extra_keys is None
+    assert block_hashes[2].extra_keys is None
+
+    blocks = manager.allocate_slots(req0, 59, computed_blocks)
+    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    req0.num_computed_tokens = 59
+
+    # Append slots without allocating a new block.
+    for _ in range(5):
+        req0.append_output_token_ids(8)
+    new_blocks = manager.allocate_slots(req0, 5)
+    assert new_blocks is not None and len(new_blocks) == 0
+
+    # Now one more block is full; it should not have extra keys.
+    assert len(block_hashes) == 4
+    assert block_hashes[3].extra_keys is None
+
+    # Test cache hit with a new request that has the same salt.
+    token_ids = common_token_ids + [4] * 11
+    req1 = make_request("1", token_ids, cache_salt="salt1")
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    # Should match only a prefix of 3 blocks.
+    assert len(computed_blocks) == 3
+    assert num_computed_tokens == 3 * block_size
+
+    # Test cache miss with same content but different salt.
+ token_ids = common_token_ids + [4] * 11 + req2 = make_request("2", token_ids, cache_salt="salt2") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks) == 0 + assert num_computed_tokens == 0 + block_hashes = manager.req_to_block_hashes[req2.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("salt2", ) + + def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ This is a unit test that tests the correctness of the allocate_slots diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 30fa9e371ad1..dcf494825b0d 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest: eos_token_id=None, arrival_time=time.time(), lora_request=None, + cache_salt=None, ) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 8cc36fa163f7..5514a328497f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -43,6 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: eos_token_id=None, arrival_time=time.time(), lora_request=None, + cache_salt=None, ) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index d2bb7d88fef2..fac701c4ca35 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool, reason should be "stop" (i.e. first control token causes stop and is represented in output text) - * else, the detokenized string should be + * else, the detokenized string should be ... and the finish reason should be "stop" (i.e. first control token causes stop but is not represented in output text.) 
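The `test_cache_key_salting` test added above relies on the salt entering only the first block's hash, with every later block inheriting it through the parent-hash chain, so two requests with different salts miss on every block. A minimal, self-contained sketch of that chaining (the helpers below are hypothetical and not vLLM's actual implementation):

```python
from typing import Optional

def hash_block(parent_hash: Optional[int], tokens: tuple,
               extra_keys: Optional[tuple] = None) -> int:
    # A block hash covers the parent hash, the block tokens, and any extra
    # keys (LoRA ID, multi-modal hashes, or the cache salt on block 0).
    return hash((parent_hash, tokens, extra_keys))

def hash_full_blocks(token_ids: list, block_size: int,
                     cache_salt: Optional[str] = None) -> list:
    hashes: list = []
    parent: Optional[int] = None
    for i in range(len(token_ids) // block_size):
        block = tuple(token_ids[i * block_size:(i + 1) * block_size])
        # Only the first block carries the salt as an extra key; later
        # blocks inherit it implicitly through the parent hash.
        extra = (cache_salt, ) if (i == 0 and cache_salt) else None
        parent = hash_block(parent, block, extra)
        hashes.append(parent)
    return hashes

# Same tokens, different salts: every block hash differs, so no block
# can be reused across the two requests.
tokens = [i for i in range(3) for _ in range(16)] + [3] * 11
assert hash_full_blocks(tokens, 16, "salt1") != hash_full_blocks(tokens, 16, "salt2")
assert hash_full_blocks(tokens, 16, "salt1") == hash_full_blocks(tokens, 16, "salt1")
```

Because the parent hash feeds into every subsequent block hash, salting only the first block is enough to partition the entire cache by salt.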
@@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
         mm_placeholders=None,
         eos_token_id=eos_token_id,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         mm_placeholders=None,
         eos_token_id=None,
         lora_request=None,
+        cache_salt=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors):
             mm_placeholders=None,
             eos_token_id=None,
             lora_request=None,
+            cache_salt=None,
             sampling_params=SamplingParams(),
         ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index d444442a9762..389557dfb7c3 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -14,6 +14,7 @@
                       ValidationInfo, field_validator, model_validator)
 from typing_extensions import TypeAlias
 
+from vllm import envs
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
@@ -408,6 +409,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker from guessing prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bits). Not supported by vLLM engine V0."))
 
     # doc: end-chat-completion-extra-params
 
@@ -726,6 +736,20 @@ def check_generation_prompt(cls, data):
                 "`add_generation_prompt` to True.")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -1622,9 +1646,9 @@ class TranscriptionRequest(OpenAIBaseModel):
     # doc: begin-transcription-extra-params
     stream: Optional[bool] = False
-    """Custom field not present in the original OpenAI definition. When set, 
+    """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
-    Completion endpoint. 
+    Completion endpoint.
     """
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
@@ -1642,7 +1666,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     top_p: Optional[float] = None
-    """Enables nucleus (top-p) sampling, where tokens are selected from the 
+    """Enables nucleus (top-p) sampling, where tokens are selected from the
     smallest possible set whose cumulative probability exceeds `p`.
""" @@ -1650,7 +1674,7 @@ class TranscriptionRequest(OpenAIBaseModel): """Limits sampling to the `k` most probable tokens at each step.""" min_p: Optional[float] = None - """Filters out tokens with a probability lower than `min_p`, ensuring a + """Filters out tokens with a probability lower than `min_p`, ensuring a minimum likelihood threshold during sampling. """ diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c3121eff562d..6123811aabe1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -470,6 +470,9 @@ async def _preprocess_chat( if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return conversation, [request_prompt], [engine_prompt] def _log_inputs( diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 970b36bca9be..167189ed108e 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -28,6 +28,11 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + class TokensPrompt(TypedDict): """Schema for a tokenized prompt.""" @@ -52,6 +57,11 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ @@ -141,11 +151,17 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, + cache_salt: Optional[str] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -154,6 +170,8 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids + if cache_salt is not None: + inputs["cache_salt"] = cache_salt return inputs @@ -217,7 +235,7 @@ def zip_enc_dec_prompts( """ Zip encoder and decoder prompts together into a list of :class:`ExplicitEncoderDecoderPrompt` instances. - + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same dictionary will be used for every encoder/decoder prompt. If an iterable is provided, it will be zipped with the encoder/decoder prompts. 
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 56b60b893913..83e6907f8c49 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,7 +17,8 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt +from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, + is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -283,6 +284,29 @@ async def _process_multimodal_async( return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) + def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, + ParsedTextPrompt, + ParsedTokensPrompt]): + prompt_text = None + prompt_token_ids = None + token_type_ids = None + cache_salt = None + + if parsed_prompt["type"] == "str": + prompt_text = parsed_prompt["content"] + else: + cache_salt = parsed_prompt["content"].get("cache_salt") + if parsed_prompt["type"] == "text": + prompt_text = parsed_prompt["content"]["prompt"] + elif parsed_prompt["type"] == "tokens": + prompt_token_ids = parsed_prompt["content"].get( + "prompt_token_ids") + token_type_ids = parsed_prompt["content"].get("token_type_ids") + else: + assert_never(parsed_prompt) + + return prompt_text, prompt_token_ids, token_type_ids, cache_salt + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, @@ -304,70 +328,36 @@ def _prompt_to_llm_inputs( * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) - - if parsed["type"] == "str": - prompt_text = parsed["content"] - prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ + self._get_prompt_data(parsed) + + # If multimodal data is present, process and return immediately + if parsed["type"] != "str" and parsed["content"].get( + "multi_modal_data") is not None: + inputs = self._process_multimodal( + prompt_text if prompt_text is not None else prompt_token_ids, + parsed["content"]["multi_modal_data"], + parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, + return_mm_hashes=return_mm_hashes, ) + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + return inputs - if parsed["type"] == "tokens": - tokens_content = parsed["content"] - - prompt_token_ids = tokens_content["prompt_token_ids"] - token_type_ids = tokens_content.get("token_type_ids") - multi_modal_data = tokens_content.get("multi_modal_data") - mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return self._process_multimodal( - prompt_token_ids, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - - return token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - ) - - if parsed["type"] == "text": - text_content = parsed["content"] - - prompt_text = text_content["prompt"] - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return self._process_multimodal( - prompt_text, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - + if prompt_token_ids is None: 
prompt_token_ids = self._tokenize_prompt( prompt_text, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - assert_never(parsed) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + cache_salt=cache_salt, + ) async def _prompt_to_llm_inputs_async( self, @@ -379,64 +369,35 @@ async def _prompt_to_llm_inputs_async( """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) - if parsed["type"] == "str": - prompt_text = parsed["content"] - prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) + prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ + self._get_prompt_data(parsed) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, + if parsed["type"] != "str" and parsed["content"].get( + "multi_modal_data") is not None: + inputs = await self._process_multimodal_async( + prompt_token_ids if prompt_text is None else prompt_text, + parsed["content"]["multi_modal_data"], + parsed["content"].get("mm_processor_kwargs"), + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + return inputs - if parsed["type"] == "tokens": - tokens_content = parsed["content"] - - prompt_token_ids = tokens_content["prompt_token_ids"] - multi_modal_data = tokens_content.get("multi_modal_data") - mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return await self._process_multimodal_async( - prompt_token_ids, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - - return token_inputs(prompt_token_ids=prompt_token_ids) - - if parsed["type"] == "text": - text_content = parsed["content"] - - prompt_text = text_content["prompt"] - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return await self._process_multimodal_async( - prompt_text, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - + if prompt_token_ids is None: prompt_token_ids = await self._tokenize_prompt_async( prompt_text, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - assert_never(parsed) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + cache_salt=cache_salt, + ) def _build_enc_dec_llm_inputs( self, @@ -516,6 +477,11 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( mm_hashes=inputs["mm_hashes"], mm_placeholders=inputs["mm_placeholders"], ) + + cache_salt = inputs.get("cache_salt") + if cache_salt is not None: + decoder_inputs["cache_salt"] = cache_salt + elif inputs["type"] == "token": # Text-only inputs encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6855808e8e44..978fb4231939 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -826,6 +826,11 @@ class MultiModalInputs(TypedDict): :code:`prompt_token_ids`. 
""" + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + class MultiModalEncDecInputs(MultiModalInputs): """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index d6ba8f1bcffe..e8745a8f1f90 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1789,7 +1789,7 @@ def create_encoder_prompt( mm_data: MultiModalDataDict, ) -> Union[str, list[int]]: """ - Create input prompt for the encoder. HF processor will be applied on + Create input prompt for the encoder. HF processor will be applied on this prompt during profiling and generation. """ raise NotImplementedError diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3026ecc1c968..27c515835087 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -275,7 +275,10 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. - return bool(request.mm_positions) or (request.lora_request is not None) + # Request with provided cache salt need to include the salt. + return bool(request.mm_positions) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -380,8 +383,10 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) + cache_salt_keys: list[str] = [request.cache_salt] if ( + start_token_idx == 0 and request.cache_salt) else [] - extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys if not extra_keys: return None, new_start_mm_idx @@ -657,10 +662,10 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ Only models with one type of KV cache are supported yet. This function tries - to convert the KV cache specs to one type if the model is a hybrid model + to convert the KV cache specs to one type if the model is a hybrid model with multiple type of KV cache. It will convert all SlidingWindowSpec to FullAttentionSpec if both types are present. 
- + Args: kv_cache_spec: The kv cache spec of each attention layer in the model """ diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0474669610cd..e33d1a1e5dcd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -57,6 +57,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] + cache_salt: Optional[str] # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index b98a31773a15..27d70a781471 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -317,6 +317,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, + cache_salt=decoder_inputs.get("cache_salt"), ) def _validate_model_inputs(self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b9b666f936a..fde366d61c7d 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -29,6 +29,7 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, + cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -51,6 +52,7 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 + self.cache_salt: Optional[str] = cache_salt # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -89,6 +91,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), + cache_salt=request.cache_salt, ) def append_output_token_ids(