From df8dcc5a4a384486e11b76fc117c1e7211e1f55f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Aug 2025 14:52:04 +0000 Subject: [PATCH 01/14] [Core] Separate MM IPC cache from processor cache Signed-off-by: DarkLight1337 --- docs/configuration/conserving_memory.md | 33 +++-- docs/configuration/optimization.md | 84 ++++++------ examples/offline_inference/mistral-small.py | 9 +- examples/offline_inference/vision_language.py | 8 +- tests/models/utils.py | 7 +- tests/multimodal/test_cache.py | 51 ++++++++ tests/multimodal/test_processing.py | 48 +------ vllm/config.py | 43 +++++- vllm/engine/arg_utils.py | 26 ++-- vllm/entrypoints/cli/serve.py | 12 +- vllm/envs.py | 6 +- vllm/multimodal/cache.py | 95 ++++++++++++++ vllm/multimodal/processing.py | 53 +------- vllm/v1/core/kv_cache_utils.py | 4 +- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/core.py | 10 +- vllm/v1/engine/llm_engine.py | 2 +- vllm/v1/engine/mm_input_cache.py | 91 ------------- vllm/v1/engine/mm_ipc_cache.py | 122 ++++++++++++++++++ vllm/v1/engine/processor.py | 15 +-- 20 files changed, 428 insertions(+), 293 deletions(-) create mode 100644 tests/multimodal/test_cache.py create mode 100644 vllm/multimodal/cache.py delete mode 100644 vllm/v1/engine/mm_input_cache.py create mode 100644 vllm/v1/engine/mm_ipc_cache.py diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 4d5c961af98f..258a97a1926b 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,8 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). +- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process). +- (Multi-modal models only) you can set the size of multi-modal IPC cache by setting `mm_ipc_cache_gb` engine argument (default 4 GiB per engine core process). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits @@ -129,20 +130,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory. Here are some examples: -??? code - - ```python - from vllm import LLM +```python +from vllm import LLM - # Available for Qwen2-VL series models - llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_kwargs={ - "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 - }) - - # Available for InternVL series models - llm = LLM(model="OpenGVLab/InternVL2-2B", - mm_processor_kwargs={ - "max_dynamic_patch": 4, # Default is 12 - }) - ``` +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) + +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 811925c19e63..f9d7a13e2248 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -2,6 +2,9 @@ This guide covers optimization strategies and performance tuning for vLLM V1. +!!! tip + Running out of memory? 
Consult [this guide](./conserving_memory.md) on how to conserve memory + ## Preemption Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. @@ -126,62 +129,63 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. -## Reducing Memory Usage +## Multi-modal Processing -If you encounter out-of-memory issues, consider these strategies: +### Processor Cache -### Context Length and Batch Size +By default, the multi-modal processor cache is enabled to avoid repeatedly processing +the same multi-modal inputs via Hugging Face `AutoProcessor`, +which commonly occurs in multi-turn conversations. -You can reduce memory usage by limiting the context length and batch size: +You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable +(default 4 GiB per API process). -```python -from vllm import LLM +If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`: -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - max_model_len=2048, # Limit context window - max_num_seqs=4 # Limit batch size -) +```python +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + disable_mm_preprocessor_cache=True) ``` -### Adjust CUDA Graph Compilation +### IPC Cache -CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level: +By default, the multi-modal IPC cache is enabled to avoid repeatedly transferring +the same multi-modal inputs between API and engine core processes, +which commonly occurs in multi-turn conversations. -```python -from vllm import LLM -from vllm.config import CompilationConfig, CompilationLevel +You can adjust the size of the cache by setting the `mm_ipc_cache_gb` engine argument +(default 4 GiB per engine core process). -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes - ) -) -``` - -Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`: +If you do not benefit much from the cache, you can disable it completely by setting it to `0`: ```python -from vllm import LLM +# Use a larger IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_ipc_cache_gb=8) -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True # Disable CUDA graph compilation -) +# Fully disable the IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_ipc_cache_gb=0) ``` -### Multimodal Models +### Parallel Processing -For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request: +You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). +This is useful when input processing (which is run inside the API server) +becomes a bottleneck compared to model execution (which is run inside engine core) +and you have excess CPU capacity. 
-```python -from vllm import LLM +```console +# Run 4 API processes and 1 engine core process +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -# Accept up to 2 images per prompt -llm = LLM( - model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"image": 2} -) +# Run 4 API processes and 2 engine core processes +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 ``` + +!!! note + API server scale-out is only available for online inference. + +!!! note + Multi-modal IPC cache is disabled when API server scale-out is enabled + because it requires a one-to-one correspondance between API and engine core processes. diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index a38fc9216d40..57c8dce78e2c 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -69,6 +69,7 @@ def run_simple_demo(args: argparse.Namespace): max_num_seqs=2, tensor_parallel_size=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + mm_ipc_cache_gb=0 if args.disable_mm_ipc_cache else 4, ) prompt = "Describe this image in one sentence." @@ -106,6 +107,7 @@ def run_advanced_demo(args: argparse.Namespace): max_model_len=max_img_per_msg * max_tokens_per_img, tensor_parallel_size=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + mm_ipc_cache_gb=0 if args.disable_mm_ipc_cache else 4, ) prompt = "Describe the following image." @@ -166,7 +168,12 @@ def parse_args(): parser.add_argument( "--disable-mm-preprocessor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", + ) + parser.add_argument( + "--disable-mm-ipc-cache", + action="store_true", + help="If True, disables caching of multi-modal transfer from P0 to P1.", ) return parser.parse_args() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 16bb3712f551..e15aac74bd43 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1565,7 +1565,12 @@ def parse_args(): parser.add_argument( "--disable-mm-preprocessor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", + ) + parser.add_argument( + "--disable-mm-ipc-cache", + action="store_true", + help="If True, disables caching of multi-modal transfer from P0 to P1.", ) parser.add_argument( @@ -1604,6 +1609,7 @@ def main(args): engine_args = asdict(req_data.engine_args) | { "seed": args.seed, "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, + "mm_ipc_cache_gb": 0 if args.disable_mm_ipc_cache else 4, } llm = LLM(**engine_args) diff --git a/tests/models/utils.py b/tests/models/utils.py index 3cd0721be1b6..86f71adb37fe 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from vllm.config import ModelConfig, RunnerOption +from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs @@ -256,11 +256,12 @@ def check_logprobs_close( def build_model_context( model_id: str, runner: RunnerOption = "auto", - dtype: Union[str, torch.dtype] = "auto", + dtype: ModelDType = "auto", model_config_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: 
Optional[dict[str, Any]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None, disable_mm_preprocessor_cache: bool = True, + mm_ipc_cache_gb: int = 0, ): """Creates an InputContext for a given model. @@ -278,6 +279,7 @@ def build_model_context( model_info.check_transformers_version(on_fail="skip") model_config_kwargs = model_config_kwargs or {} + limit_mm_per_prompt = limit_mm_per_prompt or {} model_config = ModelConfig( model_id, runner=runner, @@ -290,6 +292,7 @@ def build_model_context( mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, + mm_ipc_cache_gb=mm_ipc_cache_gb, hf_overrides=model_info.hf_overrides, **model_config_kwargs, ) diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py new file mode 100644 index 000000000000..e07b73bd257d --- /dev/null +++ b/tests/multimodal/test_cache.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField) + + +def _dummy_elem(modality: str, key: str, size: int): + return MultiModalFieldElem( + modality=modality, + key=key, + data=torch.empty((size, ), dtype=torch.int8), + field=MultiModalSharedField(1), + ) + + +def _dummy_item(modality: str, size_by_key: dict[str, int]): + return MultiModalKwargsItem.from_elems([ + _dummy_elem(modality, key, size) for key, size in size_by_key.items() + ]) + + +def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): + return MultiModalKwargs.from_items([ + _dummy_item(modality, size_by_key) + for modality, size_by_key in size_by_key_modality.items() + ]) + + +# yapf: disable +@pytest.mark.parametrize( + ("item", "expected_size"), + [ + (_dummy_item("a", {"a1": 100}), 100), + (_dummy_item("a", {"a1": 100, "a2": 110}), 210), + (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + ], +) +# yapf: enable +def test_cache_item_size(item, expected_size): + cache = MultiModalCache.get_lru_cache(2048, type(item)) + + cache[""] = item + assert cache.currsize == expected_size + + cache[""] = MultiModalCacheItemMetadata.wraps(item) + assert cache.currsize == expected_size diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 508c773b8aed..cb489c47fd8f 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -6,20 +6,15 @@ import numpy as np import pytest -import torch from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalSharedField) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - ProcessingCache, PromptIndexTargets, - PromptInsertion, PromptReplacement, - apply_text_matches, + PromptIndexTargets, PromptInsertion, + PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, @@ -902,45 +897,6 @@ def test_find_mm_placeholders( assert result == expected -def _dummy_elem(modality: str, key: str, size: int): - return MultiModalFieldElem( - modality=modality, - key=key, 
- data=torch.empty((size, ), dtype=torch.int8), - field=MultiModalSharedField(1), - ) - - -def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() - ]) - - -def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargs.from_items([ - _dummy_item(modality, size_by_key) - for modality, size_by_key in size_by_key_modality.items() - ]) - - -# yapf: disable -@pytest.mark.parametrize( - ("item", "expected_size"), - [ - (_dummy_item("a", {"a1": 100}), 100), - (_dummy_item("a", {"a1": 100, "a2": 110}), 210), - (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - ], -) -# yapf: enable -def test_cache_item_size(item, expected_size): - cache = ProcessingCache.get_lru_cache(2048, type(item)) - cache[""] = item - - assert cache.currsize == expected_size - - @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), diff --git a/vllm/config.py b/vllm/config.py index 5c300e327397..af95487c7148 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -444,8 +444,16 @@ class ModelConfig: model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ disable_mm_preprocessor_cache: bool = False - """If `True`, disable caching of the multi-modal preprocessor/mapper (not - recommended).""" + """If `True`, disable caching of the multi-modal processor.""" + mm_ipc_cache_gb: int = 4 + """The size (in GiB) of the multi-modal IPC cache, which is used to avoid + transfer of past multi-modal inputs between API and engine core processes. + + Since the cache is located in engine core, it is duplicated for each + engine core process, resulting in a total memory usage of + `mm_ipc_cache_gb * data_parallel_size`. + + Set to `0` to disable this cache completely (not recommended).""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -884,15 +892,16 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: mm_processor_kwargs=self.mm_processor_kwargs, disable_mm_preprocessor_cache=self. 
disable_mm_preprocessor_cache, + mm_ipc_cache_gb=self.mm_ipc_cache_gb, interleave_mm_strings=self.interleave_mm_strings) return None - def set_disable_mm_preprocessor_cache(self, value: bool) -> None: + def set_mm_ipc_cache_gb(self, value: int) -> None: mm_config = self.get_multimodal_config() - self.disable_mm_preprocessor_cache = value - mm_config.disable_mm_preprocessor_cache = value + self.mm_ipc_cache_gb = value + mm_config.mm_ipc_cache_gb = value def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( @@ -1686,6 +1695,16 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def processor_return_mm_hashes(self) -> bool: + """Whether the multi-modal processor should output hashes.""" + mm_config = self.multimodal_config + if mm_config is None: + return False + + return (not mm_config.disable_mm_preprocessor_cache + or mm_config.mm_ipc_cache_gb > 0) + @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -3363,7 +3382,19 @@ class MultiModalConfig: disable_mm_preprocessor_cache: bool = False """ - If `True`, disable caching of the processed multi-modal inputs. + If `True`, disable caching of the multi-modal processor. + """ + + mm_ipc_cache_gb: int = 4 + """ + The size (in GiB) of the multi-modal IPC cache, which is used to avoid + transfer of past multi-modal inputs between API and engine core processes. + + Since the cache is located in engine core, it is duplicated for each + engine core process, resulting in a total memory usage of + `mm_ipc_cache_gb * data_parallel_size`. + + Set to `0` to disable this cache completely (not recommended). """ interleave_mm_strings: bool = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5eb9660cd1e8..25338f8459c6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -360,6 +360,7 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = \ MultiModalConfig.disable_mm_preprocessor_cache + mm_ipc_cache_gb: int = MultiModalConfig.mm_ipc_cache_gb # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -722,6 +723,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--disable-mm-preprocessor-cache", **multimodal_kwargs["disable_mm_preprocessor_cache"]) + multimodal_group.add_argument("--mm-ipc-cache-gb", + **multimodal_kwargs["mm_ipc_cache_gb"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -923,6 +926,7 @@ def create_model_config(self) -> ModelConfig: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, + mm_ipc_cache_gb=self.mm_ipc_cache_gb, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1230,17 +1234,17 @@ def create_engine_config( enable_multimodal_encoder_data_parallel, ) - supports_mm_preprocessor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not supports_mm_preprocessor_cache - and model_config.is_multimodal_model - and not model_config.disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with data parallelism when there does not exist a " - "one-to-one correspondance between API 
process and " - "EngineCore process, so the cache will be disabled.") - model_config.set_disable_mm_preprocessor_cache(True) + if model_config.is_multimodal_model: + dp_supports_mm_ipc_cache = (self.data_parallel_size == 1 + or data_parallel_external_lb) + if (not dp_supports_mm_ipc_cache + and model_config.mm_ipc_cache_gb > 0): + logger.warning( + "Multi-modal IPC cache is disabled because " + "it is not compatible with data parallelism when " + "there does not exist a one-to-one correspondance " + "between API and engine core processes.") + model_config.set_mm_ipc_cache_gb(0) speculative_config = self.create_speculative_config( target_model_config=model_config, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 9762a1de9edd..b3ed79e67cef 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers = args.api_server_count assert num_api_servers > 0 - orig_disable_mm_preprocessor_cache = args.disable_mm_preprocessor_cache + orig_mm_ipc_cache_gb = args.mm_ipc_cache_gb if num_api_servers > 1: setup_multiprocess_prometheus() # Not compatible with API server scale-out - args.disable_mm_preprocessor_cache = True + args.mm_ipc_cache_gb = 0 listen_address, sock = setup_server(args) @@ -161,11 +161,9 @@ def run_multi_api_server(args: argparse.Namespace): raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " "with api_server_count > 1") - if model_config.is_multimodal_model and not ( - orig_disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with api_server_count > 1, so the cache will be disabled.") + if model_config.is_multimodal_model and orig_mm_ipc_cache_gb > 0: + logger.warning("Multi-modal IPC cache is disabled because it " + "is not compatible with `api_server_count > 1`.") executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/envs.py b/vllm/envs.py index 78f955f78a98..b51dd364285a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,7 +64,7 @@ VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" - VLLM_MM_INPUT_CACHE_GIB: int = 8 + VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -552,8 +552,8 @@ def get_vllm_port() -> Optional[int]: "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - # Cache size (in GiB) for multimodal input cache - # Default is 4 GiB + # Cache size (in GiB per API process) for multimodal input cache + # Default is 4 GiB per API process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py new file mode 100644 index 000000000000..262b22e554b9 --- /dev/null +++ b/vllm/multimodal/cache.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys +from collections.abc import Mapping +from dataclasses import dataclass +from typing import TypeVar, Union + +import torch + +from vllm.jsontree import json_map_leaves, json_reduce_leaves +from vllm.logger import init_logger +from vllm.utils import GiB_bytes, LRUCache + +from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors + +logger = init_logger(__name__) + + +@dataclass +class 
MultiModalCacheItemMetadata: + size: int + + @classmethod + def wraps(cls, value: "MultiModalCacheValue"): + return cls(size=MultiModalCache.get_item_size(value)) + + +MultiModalCacheValue = Union[ + MultiModalKwargs, + MultiModalKwargsItem, + Mapping[str, NestedTensors], + MultiModalCacheItemMetadata, +] + +_V = TypeVar("_V", bound=MultiModalCacheValue) + + +class MultiModalCache: + + @classmethod + def get_leaf_size( + cls, + leaf: object, + *, + debug: bool = False, + ) -> int: + # MultiModalKwargs is not a subclass of dict + if isinstance(leaf, MultiModalKwargs): + return cls.get_item_size(leaf.data, debug=debug) + + # MultiModalKwargsItem is not a subclass of dict + if isinstance(leaf, MultiModalKwargsItem): + leaf_data = {k: v.data for k, v in leaf.items()} + return cls.get_item_size(leaf_data, debug=debug) + + # sys.getsizeof doesn't work for tensors + if isinstance(leaf, torch.Tensor): + return leaf.nbytes + + if isinstance(leaf, MultiModalCacheItemMetadata): + return leaf.size + + return sys.getsizeof(leaf) + + @classmethod + def get_item_size( + cls, + value: MultiModalCacheValue, + *, + debug: bool = False, + ) -> int: + size = json_reduce_leaves( + lambda a, b: a + b, + json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug), + value), + ) + + if debug: + logger.debug("Calculated size of %s to be %.2f GiB", type(value), + size / GiB_bytes) + + return size + + @classmethod + def get_lru_cache( + cls, + capacity_gb: float, + value_type: type[_V], + *, + debug: bool = False, + ) -> LRUCache[str, _V]: + return LRUCache( + GiB_bytes * capacity_gb, + getsizeof=lambda x: cls.get_item_size(x, debug=debug), + ) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 46240855d12a..0378539495fd 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import sys from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, @@ -16,16 +15,16 @@ from typing_extensions import assert_never from vllm.inputs import InputProcessingContext -from vllm.jsontree import json_map_leaves, json_reduce_leaves from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) -from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby +from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby +from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors, PlaceholderRange) + MultiModalKwargsItem, PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -888,9 +887,6 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) -_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") - - class ProcessingCacheOptionalItem(NamedTuple): key: str value: Optional[MultiModalKwargsItem] @@ -901,48 +897,7 @@ class ProcessingCacheItem(NamedTuple): value: MultiModalKwargsItem -class ProcessingCache: - - @staticmethod - def get_lru_cache( - capacity_gb: float, - value_type: type[_V], - *, - debug: bool = False, - ) -> LRUCache[str, _V]: - - def get_leaf_size(leaf: object) -> int: - # MultiModalKwargs is not a subclass of 
dict - if isinstance(leaf, MultiModalKwargs): - return get_item_size(leaf.data) - - # MultiModalKwargsItem is not a subclass of dict - if isinstance(leaf, MultiModalKwargsItem): - leaf_data = {k: v.data for k, v in leaf.items()} - return get_item_size(leaf_data) - - # sys.getsizeof doesn't work for tensors - if isinstance(leaf, torch.Tensor): - return leaf.nbytes - - return sys.getsizeof(leaf) - - def get_item_size( - value: Union[MultiModalKwargs, MultiModalKwargsItem, - Mapping[str, NestedTensors]] - ) -> int: - size = json_reduce_leaves( - lambda a, b: a + b, - json_map_leaves(get_leaf_size, value), - ) - - if debug: - logger.debug("Calculated size of %s to be %.2f GiB", - type(value), size / GiB_bytes) - - return size - - return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size) +class ProcessingCache(MultiModalCache): def __init__( self, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index eab1560b1a18..ad26f20aa3a4 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( "The number of multi-modal positions and hashes must match. This " - "is likely because you do not enable MM preprocessor hashing. " - "Please set disable_mm_preprocessor_cache=False.") + "is likely because you did not enable MM IPC caching. " + "Please set `mm_ipc_cache_gb > 0`.") # Note that we assume mm_positions is sorted by offset. # We do not need to check all mm inputs if the start token index is out of diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 45f450291ab6..049b17031ec9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -567,7 +567,7 @@ async def stop_profile(self) -> None: async def reset_mm_cache(self) -> None: self.processor.mm_registry.reset_processor_cache() - self.processor.mm_input_cache_client.reset() + self.processor.mm_ipc_cache.reset() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 79c47e102888..43ad5d5f5217 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache +from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCache from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -124,9 +124,7 @@ def __init__(self, log_stats=self.log_stats, ) - # Setup MM Input Mapper. - self.mm_input_cache_server = MirroredProcessingCache( - vllm_config.model_config) + self.mm_ipc_cache = MultiModalIPCCache(vllm_config.model_config) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously @@ -347,7 +345,7 @@ def reset_mm_cache(self): logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + self.mm_ipc_cache.reset() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -413,7 +411,7 @@ def preprocess_add_request( # Note on thread safety: no race condition. 
# `mm_input_cache_server` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. - request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( + request.mm_inputs = self.mm_ipc_cache.get_and_update( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index efbdffbc0900..10b2e380f0bb 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -272,7 +272,7 @@ def stop_profile(self): def reset_mm_cache(self): self.processor.mm_registry.reset_processor_cache() - self.processor.mm_input_cache_client.reset() + self.processor.mm_ipc_cache.reset() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index abe98a13dfd3..000000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import Optional - -from vllm.envs import VLLM_MM_INPUT_CACHE_GIB -from vllm.multimodal import MultiModalKwargs -from vllm.multimodal.processing import ProcessingCache -from vllm.utils import is_list_of - -# The idea of multimodal preprocessing caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- Client: -# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs -# with built-in caching functionality, with mm_hash as its identifier. -# - MirroredProcessingCache to keep track of the cached entries and -# determine whether to send the MultiModalKwargs to P1. -# -# -- Server: -# - MirroredProcessingCache to store the MultiModalKwargs from P0. -# -# The caching for both client and server is mirrored, and this allows us -# to avoid the serialization of "mm_inputs" (like pixel values) between -# client (=P0) and server (=P1) processes if the mm_hash is found in the client -# cache. - -# Both Client and Server must use the same cache size -# (to perform mirrored caching). This cache size is set by the environment -# variable VLLM_MM_INPUT_CACHE_GIB. 
- - -class MirroredProcessingCache: - - def __init__(self, model_config): - mm_config = model_config.multimodal_config - disable_mm_preprocessor_cache = ( - mm_config is not None and mm_config.disable_mm_preprocessor_cache) - self.use_cache = not disable_mm_preprocessor_cache - self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, - MultiModalKwargs) - - def get_and_update_p0( - self, - mm_inputs: Sequence[MultiModalKwargs], - mm_hashes: list[str], - ) -> Sequence[Optional[MultiModalKwargs]]: - assert len(mm_inputs) == len(mm_hashes) - - if not self.use_cache: - assert is_list_of(mm_inputs, MultiModalKwargs) - return mm_inputs - - full_mm_inputs = list[Optional[MultiModalKwargs]]() - for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - mm_input = None - else: - self.mm_cache[mm_hash] = mm_input - - full_mm_inputs.append(mm_input) - - return full_mm_inputs - - def get_and_update_p1( - self, - mm_inputs: Sequence[Optional[MultiModalKwargs]], - mm_hashes: list[str], - ) -> Sequence[MultiModalKwargs]: - assert len(mm_inputs) == len(mm_hashes) - - if not self.use_cache: - assert is_list_of(mm_inputs, MultiModalKwargs) - return mm_inputs - - full_mm_inputs = list[MultiModalKwargs]() - for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - if mm_input is None: - mm_input = self.mm_cache[mm_hash] - else: - self.mm_cache[mm_hash] = mm_input - - full_mm_inputs.append(mm_input) - - return full_mm_inputs - - def reset(self) -> bool: - self.mm_cache.clear() - - return True diff --git a/vllm/v1/engine/mm_ipc_cache.py b/vllm/v1/engine/mm_ipc_cache.py new file mode 100644 index 000000000000..9de27b03c2d6 --- /dev/null +++ b/vllm/v1/engine/mm_ipc_cache.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional + +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata +from vllm.utils import is_list_of + +if TYPE_CHECKING: + from vllm.config import ModelConfig + +# The idea of multimodal IPC caching is based on having a client and +# a server, where the client executes in the frontend process (=P0) and the +# server in the core process (=P1). +# +# -- Client: +# - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs +# while outputting mm_hash as its identifier. +# - MultiModalIPCCacheLookup keeps track of the cached entries and +# determine whether to send the MultiModalKwargs to P1. +# +# -- Server: +# - MultiModalIPCCache stores the MultiModalKwargs from P0. +# +# The keys of MultiModalIPCCacheLookup (in the client) +# and the keys of MultiModalIPCCacheLookup (in the server) +# are mirrored, and this allows us to avoid the serialization of `mm_inputs` +# (like pixel values) between client (=P0) and server (=P1) processes if the +# `mm_hash` is found in the client cache. + +# Both Client and Server must use the same cache size (to remain mirrored). +# This cache size is set by the config variable `mm_ipc_cache_gb`. 
+ + +class MultiModalIPCCacheLookup: + """Used by P0 to check whether multi-modal kwargs are cached in P1.""" + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.multimodal_config + + self.enabled = (mm_config is not None + and mm_config.mm_ipc_cache_gb > 0) + + self.mm_cache = MultiModalCache.get_lru_cache( + mm_config.mm_ipc_cache_gb if mm_config else 0, + MultiModalCacheItemMetadata, + ) + + def get_and_update( + self, + mm_inputs: Sequence[MultiModalKwargs], + mm_hashes: list[str], + ) -> Sequence[Optional[MultiModalKwargs]]: + assert len(mm_inputs) == len(mm_hashes) + + if not self.enabled: + assert is_list_of(mm_inputs, MultiModalKwargs) + return mm_inputs + + full_mm_inputs = list[Optional[MultiModalKwargs]]() + for mm_input, mm_hash in zip(mm_inputs, mm_hashes): + if self.mm_cache.get(mm_hash) is not None: + mm_input = None + else: + self.mm_cache[mm_hash] = \ + MultiModalCacheItemMetadata.wraps(mm_input) + + full_mm_inputs.append(mm_input) + + return full_mm_inputs + + def reset(self) -> bool: + self.mm_cache.clear() + + return True + + +class MultiModalIPCCache: + """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.multimodal_config + + self.enabled = (mm_config is not None + and mm_config.mm_ipc_cache_gb > 0) + + self.mm_cache = MultiModalCache.get_lru_cache( + mm_config.mm_ipc_cache_gb if mm_config else 0, + MultiModalKwargs, + ) + + def get_and_update( + self, + mm_inputs: Sequence[Optional[MultiModalKwargs]], + mm_hashes: list[str], + ) -> Sequence[MultiModalKwargs]: + assert len(mm_inputs) == len(mm_hashes) + + if not self.enabled: + assert is_list_of(mm_inputs, MultiModalKwargs) + return mm_inputs + + full_mm_inputs = list[MultiModalKwargs]() + for mm_input, mm_hash in zip(mm_inputs, mm_hashes): + if mm_input is None: + mm_input = self.mm_cache[mm_hash] + else: + self.mm_cache[mm_hash] = mm_input + + full_mm_inputs.append(mm_input) + + return full_mm_inputs + + def reset(self) -> bool: + self.mm_cache.clear() + + return True diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 692a7dd5640e..7b51908482f5 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,7 +19,7 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache +from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheLookup from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_outlines import ( @@ -50,11 +50,7 @@ def __init__( self.tokenizer, mm_registry) - self.mm_input_cache_client = MirroredProcessingCache(self.model_config) - - # Multi-modal hasher (for images) - self.use_hash = self.mm_input_cache_client.use_cache or \ - self.cache_config.enable_prefix_caching + self.mm_ipc_cache = MultiModalIPCCacheLookup(self.model_config) @property def mm_registry(self): @@ -256,11 +252,12 @@ def process_inputs( # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. 
+ return_mm_hashes = self.model_config.processor_return_mm_hashes processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=self.use_hash, + return_mm_hashes=return_mm_hashes, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -312,7 +309,7 @@ def process_inputs( sorted_mm_hashes, ) = merge_and_sort_multimodal_metadata( decoder_inputs["mm_placeholders"], - decoder_inputs["mm_hashes"] if self.use_hash else None, + decoder_inputs["mm_hashes"] if return_mm_hashes else None, ) # The output of merged multi-modal processor (`decoder_mm_inputs`) @@ -339,7 +336,7 @@ def process_inputs( ] if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( + sorted_mm_inputs = self.mm_ipc_cache.get_and_update( orig_sorted_mm_inputs, sorted_mm_hashes) else: sorted_mm_inputs = orig_sorted_mm_inputs From b21561281c0a97134c6f18b38b337a2ebe4b094f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Aug 2025 15:19:44 +0000 Subject: [PATCH 02/14] Period Signed-off-by: DarkLight1337 --- docs/configuration/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index f9d7a13e2248..75ec9fd04c11 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -3,7 +3,7 @@ This guide covers optimization strategies and performance tuning for vLLM V1. !!! tip - Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory + Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory. ## Preemption From c6535eaf591a610232f9e0ebd5306ec859905032 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Aug 2025 17:30:36 +0000 Subject: [PATCH 03/14] Reorganize Signed-off-by: DarkLight1337 --- docs/configuration/optimization.md | 48 ++++++++++++++++-------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 75ec9fd04c11..9a14d93518d9 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -129,7 +129,31 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. -## Multi-modal Processing +## Input Processing + +### Parallel Processing + +You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). +This is useful when input processing (which is run inside the API server) +becomes a bottleneck compared to model execution (which is run inside engine core) +and you have excess CPU capacity. + +```console +# Run 4 API processes and 1 engine core process +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 + +# Run 4 API processes and 2 engine core processes +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 +``` + +!!! note + API server scale-out is only available for online inference. + +!!! note + [Multi-modal IPC cache](#ipc-cache) is disabled when API server scale-out is enabled + because it requires a one-to-one correspondance between API and engine core processes. 
+ +## Multi-Modal Caching ### Processor Cache @@ -167,25 +191,3 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_ipc_cache_gb=0) ``` - -### Parallel Processing - -You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). -This is useful when input processing (which is run inside the API server) -becomes a bottleneck compared to model execution (which is run inside engine core) -and you have excess CPU capacity. - -```console -# Run 4 API processes and 1 engine core process -vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 - -# Run 4 API processes and 2 engine core processes -vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 -``` - -!!! note - API server scale-out is only available for online inference. - -!!! note - Multi-modal IPC cache is disabled when API server scale-out is enabled - because it requires a one-to-one correspondance between API and engine core processes. From 1855af4ef58c7764ff6abc8b4a9208994a0f274c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Aug 2025 17:42:20 +0000 Subject: [PATCH 04/14] Address comments Signed-off-by: DarkLight1337 --- vllm/config.py | 16 ++++++++++++++++ vllm/v1/engine/mm_ipc_cache.py | 24 ++++++------------------ vllm/v1/engine/processor.py | 3 ++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index af95487c7148..5f8287454f61 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1705,6 +1705,22 @@ def processor_return_mm_hashes(self) -> bool: return (not mm_config.disable_mm_preprocessor_cache or mm_config.mm_ipc_cache_gb > 0) + @property + def enable_mm_ipc_cache(self) -> bool: + """Whether the multi-modal IPC cache should be enabled.""" + mm_config = self.multimodal_config + if mm_config is None: + return False + + return mm_config.mm_ipc_cache_gb > 0 + + def get_mm_ipc_cache_gb(self) -> int: + mm_config = self.multimodal_config + if mm_config is None: + return 0 + + return mm_config.mm_ipc_cache_gb + @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding diff --git a/vllm/v1/engine/mm_ipc_cache.py b/vllm/v1/engine/mm_ipc_cache.py index 9de27b03c2d6..91dadf0f5de4 100644 --- a/vllm/v1/engine/mm_ipc_cache.py +++ b/vllm/v1/engine/mm_ipc_cache.py @@ -39,13 +39,9 @@ class MultiModalIPCCacheLookup: def __init__(self, model_config: "ModelConfig") -> None: super().__init__() - mm_config = model_config.multimodal_config - - self.enabled = (mm_config is not None - and mm_config.mm_ipc_cache_gb > 0) - + self.enabled = model_config.enable_mm_ipc_cache self.mm_cache = MultiModalCache.get_lru_cache( - mm_config.mm_ipc_cache_gb if mm_config else 0, + model_config.get_mm_ipc_cache_gb(), MultiModalCacheItemMetadata, ) @@ -72,11 +68,9 @@ def get_and_update( return full_mm_inputs - def reset(self) -> bool: + def reset(self) -> None: self.mm_cache.clear() - return True - class MultiModalIPCCache: """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" @@ -84,13 +78,9 @@ class MultiModalIPCCache: def __init__(self, model_config: "ModelConfig") -> None: super().__init__() - mm_config = model_config.multimodal_config - - self.enabled = (mm_config is not None - and mm_config.mm_ipc_cache_gb > 0) - + self.enabled = model_config.enable_mm_ipc_cache self.mm_cache = MultiModalCache.get_lru_cache( - mm_config.mm_ipc_cache_gb if mm_config else 0, + model_config.get_mm_ipc_cache_gb(), MultiModalKwargs, ) @@ -116,7 
+106,5 @@ def get_and_update( return full_mm_inputs - def reset(self) -> bool: + def reset(self) -> None: self.mm_cache.clear() - - return True diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7b51908482f5..3b2d79cd07fc 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -252,7 +252,8 @@ def process_inputs( # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. - return_mm_hashes = self.model_config.processor_return_mm_hashes + return_mm_hashes = (self.model_config.processor_return_mm_hashes + or bool(self.cache_config.enable_prefix_caching)) processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, From 4e8fad96e5fc24b281234f9ad235bc761cc1c2eb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 4 Aug 2025 17:43:32 +0000 Subject: [PATCH 05/14] Update comment Signed-off-by: DarkLight1337 --- vllm/v1/engine/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 43ad5d5f5217..467db42678e4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -409,7 +409,7 @@ def preprocess_add_request( if request.mm_hashes is not None: assert request.mm_inputs is not None # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, + # `mm_ipc_cache` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. request.mm_inputs = self.mm_ipc_cache.get_and_update( request.mm_inputs, request.mm_hashes) From b76dbf1ad03d1c2ec73dbcf49f9d849adec746cd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 6 Aug 2025 16:41:57 +0000 Subject: [PATCH 06/14] Typo Signed-off-by: DarkLight1337 --- vllm/v1/engine/mm_ipc_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/mm_ipc_cache.py b/vllm/v1/engine/mm_ipc_cache.py index 91dadf0f5de4..133823c34026 100644 --- a/vllm/v1/engine/mm_ipc_cache.py +++ b/vllm/v1/engine/mm_ipc_cache.py @@ -24,7 +24,7 @@ # - MultiModalIPCCache stores the MultiModalKwargs from P0. # # The keys of MultiModalIPCCacheLookup (in the client) -# and the keys of MultiModalIPCCacheLookup (in the server) +# and the keys of MultiModalIPCCache (in the server) # are mirrored, and this allows us to avoid the serialization of `mm_inputs` # (like pixel values) between client (=P0) and server (=P1) processes if the # `mm_hash` is found in the client cache. 
From 4f979743d334a881f3dc34c79d4a77e1ca43f24f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 01:58:51 +0000 Subject: [PATCH 07/14] Address comments Signed-off-by: DarkLight1337 --- docs/configuration/conserving_memory.md | 3 +-- docs/configuration/optimization.md | 23 +------------------ examples/offline_inference/mistral-small.py | 7 ------ examples/offline_inference/vision_language.py | 1 - tests/models/utils.py | 2 -- vllm/config.py | 23 +++++-------------- vllm/engine/arg_utils.py | 10 +++----- vllm/entrypoints/cli/serve.py | 11 +++++---- vllm/envs.py | 4 ++-- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/engine/core.py | 4 ++-- vllm/v1/engine/mm_ipc_cache.py | 17 +++++++------- vllm/v1/engine/processor.py | 4 ++-- 13 files changed, 32 insertions(+), 79 deletions(-) diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 258a97a1926b..dcaf1069bfdf 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,8 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process). -- (Multi-modal models only) you can set the size of multi-modal IPC cache by setting `mm_ipc_cache_gb` engine argument (default 4 GiB per engine core process). +- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process) - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 9a14d93518d9..75e77519c69f 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -162,7 +162,7 @@ the same multi-modal inputs via Hugging Face `AutoProcessor`, which commonly occurs in multi-turn conversations. You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable -(default 4 GiB per API process). +(default 4 GiB per API process + 4 GiB per engine core process). If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`: @@ -170,24 +170,3 @@ If you do not benefit much from the cache, you can disable it completely via `di llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", disable_mm_preprocessor_cache=True) ``` - -### IPC Cache - -By default, the multi-modal IPC cache is enabled to avoid repeatedly transferring -the same multi-modal inputs between API and engine core processes, -which commonly occurs in multi-turn conversations. - -You can adjust the size of the cache by setting the `mm_ipc_cache_gb` engine argument -(default 4 GiB per engine core process). 
- -If you do not benefit much from the cache, you can disable it completely by setting it to `0`: - -```python -# Use a larger IPC cache -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_ipc_cache_gb=8) - -# Fully disable the IPC cache -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_ipc_cache_gb=0) -``` diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index 57c8dce78e2c..59ec22a1e9fa 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -69,7 +69,6 @@ def run_simple_demo(args: argparse.Namespace): max_num_seqs=2, tensor_parallel_size=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, - mm_ipc_cache_gb=0 if args.disable_mm_ipc_cache else 4, ) prompt = "Describe this image in one sentence." @@ -107,7 +106,6 @@ def run_advanced_demo(args: argparse.Namespace): max_model_len=max_img_per_msg * max_tokens_per_img, tensor_parallel_size=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, - mm_ipc_cache_gb=0 if args.disable_mm_ipc_cache else 4, ) prompt = "Describe the following image." @@ -170,11 +168,6 @@ def parse_args(): action="store_true", help="If True, disables caching of multi-modal processor.", ) - parser.add_argument( - "--disable-mm-ipc-cache", - action="store_true", - help="If True, disables caching of multi-modal transfer from P0 to P1.", - ) return parser.parse_args() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index e15aac74bd43..373a98a8d666 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1609,7 +1609,6 @@ def main(args): engine_args = asdict(req_data.engine_args) | { "seed": args.seed, "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, - "mm_ipc_cache_gb": 0 if args.disable_mm_ipc_cache else 4, } llm = LLM(**engine_args) diff --git a/tests/models/utils.py b/tests/models/utils.py index 86f71adb37fe..0ba65532ab8c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -261,7 +261,6 @@ def build_model_context( mm_processor_kwargs: Optional[dict[str, Any]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None, disable_mm_preprocessor_cache: bool = True, - mm_ipc_cache_gb: int = 0, ): """Creates an InputContext for a given model. @@ -292,7 +291,6 @@ def build_model_context( mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, - mm_ipc_cache_gb=mm_ipc_cache_gb, hf_overrides=model_info.hf_overrides, **model_config_kwargs, ) diff --git a/vllm/config.py b/vllm/config.py index 5f8287454f61..337e6a67cb49 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -445,15 +445,6 @@ class ModelConfig: """ disable_mm_preprocessor_cache: bool = False """If `True`, disable caching of the multi-modal processor.""" - mm_ipc_cache_gb: int = 4 - """The size (in GiB) of the multi-modal IPC cache, which is used to avoid - transfer of past multi-modal inputs between API and engine core processes. - - Since the cache is located in engine core, it is duplicated for each - engine core process, resulting in a total memory usage of - `mm_ipc_cache_gb * data_parallel_size`. 
- - Set to `0` to disable this cache completely (not recommended).""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -892,16 +883,15 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: mm_processor_kwargs=self.mm_processor_kwargs, disable_mm_preprocessor_cache=self. disable_mm_preprocessor_cache, - mm_ipc_cache_gb=self.mm_ipc_cache_gb, interleave_mm_strings=self.interleave_mm_strings) return None - def set_mm_ipc_cache_gb(self, value: int) -> None: + def set_disable_mm_preprocessor_cache(self, value: bool) -> None: mm_config = self.get_multimodal_config() - self.mm_ipc_cache_gb = value - mm_config.mm_ipc_cache_gb = value + self.disable_mm_preprocessor_cache = value + mm_config.disable_mm_preprocessor_cache = value def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( @@ -1702,8 +1692,7 @@ def processor_return_mm_hashes(self) -> bool: if mm_config is None: return False - return (not mm_config.disable_mm_preprocessor_cache - or mm_config.mm_ipc_cache_gb > 0) + return not mm_config.disable_mm_preprocessor_cache @property def enable_mm_ipc_cache(self) -> bool: @@ -1712,14 +1701,14 @@ def enable_mm_ipc_cache(self) -> bool: if mm_config is None: return False - return mm_config.mm_ipc_cache_gb > 0 + return not mm_config.disable_mm_preprocessor_cache def get_mm_ipc_cache_gb(self) -> int: mm_config = self.multimodal_config if mm_config is None: return 0 - return mm_config.mm_ipc_cache_gb + return envs.VLLM_MM_INPUT_CACHE_GIB @property def is_cross_encoder(self) -> bool: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 25338f8459c6..3cb9db0e48fa 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -360,7 +360,6 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = \ MultiModalConfig.disable_mm_preprocessor_cache - mm_ipc_cache_gb: int = MultiModalConfig.mm_ipc_cache_gb # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -723,8 +722,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--disable-mm-preprocessor-cache", **multimodal_kwargs["disable_mm_preprocessor_cache"]) - multimodal_group.add_argument("--mm-ipc-cache-gb", - **multimodal_kwargs["mm_ipc_cache_gb"]) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -926,7 +923,6 @@ def create_model_config(self) -> ModelConfig: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, - mm_ipc_cache_gb=self.mm_ipc_cache_gb, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1238,13 +1234,13 @@ def create_engine_config( dp_supports_mm_ipc_cache = (self.data_parallel_size == 1 or data_parallel_external_lb) if (not dp_supports_mm_ipc_cache - and model_config.mm_ipc_cache_gb > 0): + and not model_config.disable_mm_preprocessor_cache): logger.warning( - "Multi-modal IPC cache is disabled because " + "Multi-modal processor cache is disabled because " "it is not compatible with data parallelism when " "there does not exist a one-to-one correspondance " "between API and engine core processes.") - 
model_config.set_mm_ipc_cache_gb(0) + model_config.set_disable_mm_preprocessor_cache(True) speculative_config = self.create_speculative_config( target_model_config=model_config, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b3ed79e67cef..91fd32aa36ef 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers = args.api_server_count assert num_api_servers > 0 - orig_mm_ipc_cache_gb = args.mm_ipc_cache_gb + orig_disable_mm_preprocessor_cache = args.disable_mm_preprocessor_cache if num_api_servers > 1: setup_multiprocess_prometheus() # Not compatible with API server scale-out - args.mm_ipc_cache_gb = 0 + args.disable_mm_preprocessor_cache = 0 listen_address, sock = setup_server(args) @@ -161,9 +161,10 @@ def run_multi_api_server(args: argparse.Namespace): raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " "with api_server_count > 1") - if model_config.is_multimodal_model and orig_mm_ipc_cache_gb > 0: - logger.warning("Multi-modal IPC cache is disabled because it " - "is not compatible with `api_server_count > 1`.") + if model_config.is_multimodal_model and not ( + orig_disable_mm_preprocessor_cache): + logger.warning("Multi-modal processor cache is disabled because " + "it is not compatible with `api_server_count > 1`.") executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/envs.py b/vllm/envs.py index b51dd364285a..18eca6364036 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -552,8 +552,8 @@ def get_vllm_port() -> Optional[int]: "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - # Cache size (in GiB per API process) for multimodal input cache - # Default is 4 GiB per API process + # Cache size (in GiB per process) for multimodal input cache + # Default is 4 GiB per API process + 4 GiB per engine core process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ad26f20aa3a4..ad5b0254d340 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -430,7 +430,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, raise ValueError( "The number of multi-modal positions and hashes must match. This " "is likely because you did not enable MM IPC caching. " - "Please set `mm_ipc_cache_gb > 0`.") + "Please set `disable_mm_preprocessor_cache=False`.") # Note that we assume mm_positions is sorted by offset. 
# We do not need to check all mm inputs if the start token index is out of diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 467db42678e4..c734c019d1d4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCache +from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -124,7 +124,7 @@ def __init__(self, log_stats=self.log_stats, ) - self.mm_ipc_cache = MultiModalIPCCache(vllm_config.model_config) + self.mm_ipc_cache = MultiModalIPCCacheServer(vllm_config.model_config) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously diff --git a/vllm/v1/engine/mm_ipc_cache.py b/vllm/v1/engine/mm_ipc_cache.py index 133823c34026..6038033ac584 100644 --- a/vllm/v1/engine/mm_ipc_cache.py +++ b/vllm/v1/engine/mm_ipc_cache.py @@ -14,26 +14,25 @@ # a server, where the client executes in the frontend process (=P0) and the # server in the core process (=P1). # -# -- Client: +# -- P0: # - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs # while outputting mm_hash as its identifier. -# - MultiModalIPCCacheLookup keeps track of the cached entries and +# - MultiModalIPCCacheClient keeps track of the cached entries and # determine whether to send the MultiModalKwargs to P1. # -# -- Server: -# - MultiModalIPCCache stores the MultiModalKwargs from P0. +# -- P1: +# - MultiModalIPCCacheServer stores the MultiModalKwargs from P0. # -# The keys of MultiModalIPCCacheLookup (in the client) -# and the keys of MultiModalIPCCache (in the server) +# The keys of MultiModalIPCCacheClient and MultiModalIPCCacheServer # are mirrored, and this allows us to avoid the serialization of `mm_inputs` # (like pixel values) between client (=P0) and server (=P1) processes if the # `mm_hash` is found in the client cache. # Both Client and Server must use the same cache size (to remain mirrored). -# This cache size is set by the config variable `mm_ipc_cache_gb`. +# This cache size is set by the env variable `VLLM_MM_INPUT_CACHE_GIB`. 
-class MultiModalIPCCacheLookup: +class MultiModalIPCCacheClient: """Used by P0 to check whether multi-modal kwargs are cached in P1.""" def __init__(self, model_config: "ModelConfig") -> None: @@ -72,7 +71,7 @@ def reset(self) -> None: self.mm_cache.clear() -class MultiModalIPCCache: +class MultiModalIPCCacheServer: """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" def __init__(self, model_config: "ModelConfig") -> None: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 3b2d79cd07fc..a019c549a53b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,7 +19,7 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheLookup +from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_outlines import ( @@ -50,7 +50,7 @@ def __init__( self.tokenizer, mm_registry) - self.mm_ipc_cache = MultiModalIPCCacheLookup(self.model_config) + self.mm_ipc_cache = MultiModalIPCCacheClient(self.model_config) @property def mm_registry(self): From 317c19f2d7d20d384cfb1bad001b7193290bb0ff Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:00:27 +0000 Subject: [PATCH 08/14] Reduce diff Signed-off-by: DarkLight1337 --- vllm/v1/engine/core.py | 2 +- vllm/v1/engine/{mm_ipc_cache.py => mm_input_cache.py} | 0 vllm/v1/engine/processor.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename vllm/v1/engine/{mm_ipc_cache.py => mm_input_cache.py} (100%) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c734c019d1d4..bd6e18352d26 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheServer +from vllm.v1.engine.mm_input_cache import MultiModalIPCCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig diff --git a/vllm/v1/engine/mm_ipc_cache.py b/vllm/v1/engine/mm_input_cache.py similarity index 100% rename from vllm/v1/engine/mm_ipc_cache.py rename to vllm/v1/engine/mm_input_cache.py diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a019c549a53b..77c0c2bfc1b8 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,7 +19,7 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_ipc_cache import MultiModalIPCCacheClient +from vllm.v1.engine.mm_input_cache import MultiModalIPCCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_outlines import ( From 7cd00239064ce4929d7bdec227e6acaffc537768 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:03:31 +0000 Subject: [PATCH 09/14] Reduce diff Signed-off-by: DarkLight1337 --- vllm/config.py | 16 ++-------------- vllm/engine/arg_utils.py | 6 +++--- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/core.py | 11 ++++++----- vllm/v1/engine/llm_engine.py | 2 +- 
vllm/v1/engine/mm_input_cache.py | 18 +++++++++--------- vllm/v1/engine/processor.py | 6 +++--- 7 files changed, 25 insertions(+), 36 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 337e6a67cb49..e0dc5b0af724 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1695,7 +1695,7 @@ def processor_return_mm_hashes(self) -> bool: return not mm_config.disable_mm_preprocessor_cache @property - def enable_mm_ipc_cache(self) -> bool: + def enable_mm_input_cache(self) -> bool: """Whether the multi-modal IPC cache should be enabled.""" mm_config = self.multimodal_config if mm_config is None: @@ -1703,7 +1703,7 @@ def enable_mm_ipc_cache(self) -> bool: return not mm_config.disable_mm_preprocessor_cache - def get_mm_ipc_cache_gb(self) -> int: + def get_mm_input_cache_gb(self) -> int: mm_config = self.multimodal_config if mm_config is None: return 0 @@ -3390,18 +3390,6 @@ class MultiModalConfig: If `True`, disable caching of the multi-modal processor. """ - mm_ipc_cache_gb: int = 4 - """ - The size (in GiB) of the multi-modal IPC cache, which is used to avoid - transfer of past multi-modal inputs between API and engine core processes. - - Since the cache is located in engine core, it is duplicated for each - engine core process, resulting in a total memory usage of - `mm_ipc_cache_gb * data_parallel_size`. - - Set to `0` to disable this cache completely (not recommended). - """ - interleave_mm_strings: bool = False """ Enable fully interleaved support for multimodal prompts. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3cb9db0e48fa..c63a67d492f5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1231,9 +1231,9 @@ def create_engine_config( ) if model_config.is_multimodal_model: - dp_supports_mm_ipc_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_ipc_cache + dp_supports_mm_input_cache = (self.data_parallel_size == 1 + or data_parallel_external_lb) + if (not dp_supports_mm_input_cache and not model_config.disable_mm_preprocessor_cache): logger.warning( "Multi-modal processor cache is disabled because " diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 049b17031ec9..93b85e32771e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -567,7 +567,7 @@ async def stop_profile(self) -> None: async def reset_mm_cache(self) -> None: self.processor.mm_registry.reset_processor_cache() - self.processor.mm_ipc_cache.reset() + self.processor.mm_input_cache.reset() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bd6e18352d26..98a73a1fe0b7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalIPCCacheServer +from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -124,7 +124,8 @@ def __init__(self, log_stats=self.log_stats, ) - self.mm_ipc_cache = MultiModalIPCCacheServer(vllm_config.model_config) + self.mm_input_cache = MultiModalInputCacheServer( + vllm_config.model_config) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. 
This enables us to asynchronously @@ -345,7 +346,7 @@ def reset_mm_cache(self): logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_ipc_cache.reset() + self.mm_input_cache.reset() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -409,9 +410,9 @@ def preprocess_add_request( if request.mm_hashes is not None: assert request.mm_inputs is not None # Note on thread safety: no race condition. - # `mm_ipc_cache` is reset at the end of LLMEngine init, + # `mm_input_cache` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. - request.mm_inputs = self.mm_ipc_cache.get_and_update( + request.mm_inputs = self.mm_input_cache.get_and_update( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 10b2e380f0bb..c145d3c99e71 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -272,7 +272,7 @@ def stop_profile(self): def reset_mm_cache(self): self.processor.mm_registry.reset_processor_cache() - self.processor.mm_ipc_cache.reset() + self.processor.mm_input_cache.reset() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index 6038033ac584..467ca82f61f4 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -17,13 +17,13 @@ # -- P0: # - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs # while outputting mm_hash as its identifier. -# - MultiModalIPCCacheClient keeps track of the cached entries and +# - MultiModalInputCacheClient keeps track of the cached entries and # determine whether to send the MultiModalKwargs to P1. # # -- P1: -# - MultiModalIPCCacheServer stores the MultiModalKwargs from P0. +# - MultiModalInputCacheServer stores the MultiModalKwargs from P0. # -# The keys of MultiModalIPCCacheClient and MultiModalIPCCacheServer +# The keys of MultiModalInputCacheClient and MultiModalInputCacheServer # are mirrored, and this allows us to avoid the serialization of `mm_inputs` # (like pixel values) between client (=P0) and server (=P1) processes if the # `mm_hash` is found in the client cache. @@ -32,15 +32,15 @@ # This cache size is set by the env variable `VLLM_MM_INPUT_CACHE_GIB`. 
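# Illustrative sketch (not part of this patch): rough capacity math for the
# mirrored caches described in the comment above. The per-item size below is
# an assumption purely for illustration, not a figure taken from vLLM.
import os

cache_gib = int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4"))  # vLLM default: 4
budget_bytes = cache_gib * (1 << 30)

# Suppose each processed image yields ~2.4 MB of tensors (hypothetical):
item_bytes = int(2.4 * 1024 * 1024)
print(f"~{budget_bytes // item_bytes} items fit per {cache_gib} GiB cache")
# Because client and server evict by the same per-item size and share this
# budget, their key sets stay mirrored without extra IPC round trips.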
-class MultiModalIPCCacheClient: +class MultiModalInputCacheClient: """Used by P0 to check whether multi-modal kwargs are cached in P1.""" def __init__(self, model_config: "ModelConfig") -> None: super().__init__() - self.enabled = model_config.enable_mm_ipc_cache + self.enabled = model_config.enable_mm_input_cache self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_ipc_cache_gb(), + model_config.get_mm_input_cache_gb(), MultiModalCacheItemMetadata, ) @@ -71,15 +71,15 @@ def reset(self) -> None: self.mm_cache.clear() -class MultiModalIPCCacheServer: +class MultiModalInputCacheServer: """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" def __init__(self, model_config: "ModelConfig") -> None: super().__init__() - self.enabled = model_config.enable_mm_ipc_cache + self.enabled = model_config.enable_mm_input_cache self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_ipc_cache_gb(), + model_config.get_mm_input_cache_gb(), MultiModalKwargs, ) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 77c0c2bfc1b8..9d89a8b0e9ea 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,7 +19,7 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalIPCCacheClient +from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_outlines import ( @@ -50,7 +50,7 @@ def __init__( self.tokenizer, mm_registry) - self.mm_ipc_cache = MultiModalIPCCacheClient(self.model_config) + self.mm_input_cache = MultiModalInputCacheClient(self.model_config) @property def mm_registry(self): @@ -337,7 +337,7 @@ def process_inputs( ] if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_ipc_cache.get_and_update( + sorted_mm_inputs = self.mm_input_cache.get_and_update( orig_sorted_mm_inputs, sorted_mm_hashes) else: sorted_mm_inputs = orig_sorted_mm_inputs From e45a503596eacc3a09a3b9dc28d69d64cf780766 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:05:32 +0000 Subject: [PATCH 10/14] Reduce diff Signed-off-by: DarkLight1337 --- vllm/engine/arg_utils.py | 6 +++--- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/core.py | 8 ++++---- vllm/v1/engine/llm_engine.py | 2 +- vllm/v1/engine/processor.py | 5 +++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c63a67d492f5..f063fc78c2e8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1231,9 +1231,9 @@ def create_engine_config( ) if model_config.is_multimodal_model: - dp_supports_mm_input_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_input_cache + dp_supports_mm_processor_cache = (self.data_parallel_size == 1 + or data_parallel_external_lb) + if (not dp_supports_mm_processor_cache and not model_config.disable_mm_preprocessor_cache): logger.warning( "Multi-modal processor cache is disabled because " diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 93b85e32771e..45f450291ab6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -567,7 +567,7 @@ async def stop_profile(self) -> None: async def reset_mm_cache(self) -> None: self.processor.mm_registry.reset_processor_cache() - 
self.processor.mm_input_cache.reset() + self.processor.mm_input_cache_client.reset() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 98a73a1fe0b7..78b8fe4ea676 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -124,7 +124,7 @@ def __init__(self, log_stats=self.log_stats, ) - self.mm_input_cache = MultiModalInputCacheServer( + self.mm_input_cache_server = MultiModalInputCacheServer( vllm_config.model_config) # Setup batch queue for pipeline parallelism. @@ -346,7 +346,7 @@ def reset_mm_cache(self): logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache.reset() + self.mm_input_cache_server.reset() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -410,9 +410,9 @@ def preprocess_add_request( if request.mm_hashes is not None: assert request.mm_inputs is not None # Note on thread safety: no race condition. - # `mm_input_cache` is reset at the end of LLMEngine init, + # `mm_input_cache_server` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. - request.mm_inputs = self.mm_input_cache.get_and_update( + request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c145d3c99e71..efbdffbc0900 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -272,7 +272,7 @@ def stop_profile(self): def reset_mm_cache(self): self.processor.mm_registry.reset_processor_cache() - self.processor.mm_input_cache.reset() + self.processor.mm_input_cache_client.reset() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 9d89a8b0e9ea..6e37ebeb8778 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -50,7 +50,8 @@ def __init__( self.tokenizer, mm_registry) - self.mm_input_cache = MultiModalInputCacheClient(self.model_config) + self.mm_input_cache_client = MultiModalInputCacheClient( + self.model_config) @property def mm_registry(self): @@ -337,7 +338,7 @@ def process_inputs( ] if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache.get_and_update( + sorted_mm_inputs = self.mm_input_cache_client.get_and_update( orig_sorted_mm_inputs, sorted_mm_hashes) else: sorted_mm_inputs = orig_sorted_mm_inputs From 5f6d9028a52853bb95e79c46237b892e0cc8ed46 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:10:16 +0000 Subject: [PATCH 11/14] Fix Signed-off-by: DarkLight1337 --- examples/offline_inference/vision_language.py | 5 ----- vllm/entrypoints/cli/serve.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 373a98a8d666..5dbe00199428 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1567,11 +1567,6 @@ def parse_args(): action="store_true", help="If True, disables caching of multi-modal processor.", ) - parser.add_argument( - "--disable-mm-ipc-cache", - action="store_true", - help="If True, disables caching of multi-modal transfer from P0 to P1.", - ) parser.add_argument( "--time-generate", diff --git 
a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 91fd32aa36ef..02b78f103c5a 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -144,7 +144,7 @@ def run_multi_api_server(args: argparse.Namespace): setup_multiprocess_prometheus() # Not compatible with API server scale-out - args.disable_mm_preprocessor_cache = 0 + args.disable_mm_preprocessor_cache = True listen_address, sock = setup_server(args) From b3fea687c5501f0da48db7b70f424c42539b2f7d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:11:48 +0000 Subject: [PATCH 12/14] Update Signed-off-by: DarkLight1337 --- docs/configuration/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 75e77519c69f..bb7342c93fb9 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -150,7 +150,7 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 API server scale-out is only available for online inference. !!! note - [Multi-modal IPC cache](#ipc-cache) is disabled when API server scale-out is enabled + [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled because it requires a one-to-one correspondance between API and engine core processes. ## Multi-Modal Caching From 8c46e93929f0b69bc999e939e4cd3557074c2fe1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 02:15:19 +0000 Subject: [PATCH 13/14] Fix Signed-off-by: DarkLight1337 --- vllm/config.py | 2 +- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/engine/mm_input_cache.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e0dc5b0af724..038128d2b8c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1696,7 +1696,7 @@ def processor_return_mm_hashes(self) -> bool: @property def enable_mm_input_cache(self) -> bool: - """Whether the multi-modal IPC cache should be enabled.""" + """Whether the multi-modal input cache should be enabled.""" mm_config = self.multimodal_config if mm_config is None: return False diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ad5b0254d340..38b1d9b13fda 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -429,7 +429,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( "The number of multi-modal positions and hashes must match. This " - "is likely because you did not enable MM IPC caching. " + "is likely because you did not enable MM hashing. " "Please set `disable_mm_preprocessor_cache=False`.") # Note that we assume mm_positions is sorted by offset. diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index 467ca82f61f4..bbcc2f08e341 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from vllm.config import ModelConfig -# The idea of multimodal IPC caching is based on having a client and +# The idea of multimodal input caching is based on having a client and # a server, where the client executes in the frontend process (=P0) and the # server in the core process (=P1). 
# From 6d7253f98045335676b8495785ab9290ccacafad Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 08:42:28 +0000 Subject: [PATCH 14/14] Update code comment Signed-off-by: DarkLight1337 --- vllm/v1/engine/mm_input_cache.py | 35 +++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index bbcc2f08e341..279c9f0007bc 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -15,21 +15,32 @@ # server in the core process (=P1). # # -- P0: -# - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs -# while outputting mm_hash as its identifier. -# - MultiModalInputCacheClient keeps track of the cached entries and -# determine whether to send the MultiModalKwargs to P1. +# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of +# each input multi-modal item (e.g. image), +# - BaseMultiModalProcessor processes the input items into `mm_inputs`, +# which are MultiModalKwargsItem instances that each correspond to an +# input multi-modal item. +# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding +# `mm_hash` for each item. It stores the `mm_hash` as keys and the size +# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking +# up additional memory in P0. +# - The `mm_hash` is always sent to P1. +# - The corresponding `mm_inputs` are only sent to P1 if they are not cached +# in MultiModalInputCacheServer. # # -- P1: -# - MultiModalInputCacheServer stores the MultiModalKwargs from P0. +# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0), +# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`. +# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0), +# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`. +# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to +# the engine for model execution. # -# The keys of MultiModalInputCacheClient and MultiModalInputCacheServer -# are mirrored, and this allows us to avoid the serialization of `mm_inputs` -# (like pixel values) between client (=P0) and server (=P1) processes if the -# `mm_hash` is found in the client cache. - -# Both Client and Server must use the same cache size (to remain mirrored). -# This cache size is set by the env variable `VLLM_MM_INPUT_CACHE_GIB`. +# Both Client and Server must perform cache update and eviction based on the +# same item size. This ensures that the keys of MultiModalInputCacheClient +# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 +# whether a key is cached in MultiModalInputCacheServer by querying +# MultiModalInputCacheClient without having to communicate with P1. class MultiModalInputCacheClient:
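# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): a self-contained toy version
# of the mirrored P0/P1 protocol described in the comment block above. It
# deliberately avoids the real vLLM types (MultiModalCache, MultiModalKwargs)
# and uses bytes payloads plus a plain size-bounded LRU; the names here
# (ToyLRUCache, ToyClient, ToyServer) are hypothetical.
# ---------------------------------------------------------------------------
from collections import OrderedDict
from typing import Optional


class ToyLRUCache:
    """Size-bounded LRU keyed by mm_hash."""

    def __init__(self, capacity_bytes: int) -> None:
        self.capacity_bytes = capacity_bytes
        self.used_bytes = 0
        self._items = OrderedDict()  # mm_hash -> (size, payload or None)

    def get(self, key: str):
        if key not in self._items:
            return None
        self._items.move_to_end(key)  # Mark as most recently used
        return self._items[key]

    def put(self, key: str, size: int, payload: Optional[bytes]) -> None:
        if key in self._items:
            self.used_bytes -= self._items.pop(key)[0]
        self._items[key] = (size, payload)
        self.used_bytes += size
        # Both sides evict by the same size accounting; this is what keeps
        # the client and server key sets mirrored.
        while self.used_bytes > self.capacity_bytes and self._items:
            _, (evicted_size, _) = self._items.popitem(last=False)
            self.used_bytes -= evicted_size


class ToyClient:
    """P0 side: remembers only hashes and sizes, never the payloads."""

    def __init__(self, capacity_bytes: int) -> None:
        self.cache = ToyLRUCache(capacity_bytes)

    def get_and_update(self, mm_hash: str, payload: bytes) -> Optional[bytes]:
        if self.cache.get(mm_hash) is not None:
            return None  # Server already has it; send only the hash
        self.cache.put(mm_hash, len(payload), payload=None)  # Track size only
        return payload  # First sighting; ship the payload over IPC


class ToyServer:
    """P1 side: stores payloads so repeated inputs need not be re-sent."""

    def __init__(self, capacity_bytes: int) -> None:
        self.cache = ToyLRUCache(capacity_bytes)

    def get_and_update(self, mm_hash: str, payload: Optional[bytes]) -> bytes:
        if payload is None:
            cached = self.cache.get(mm_hash)
            assert cached is not None, "client/server caches desynced"
            return cached[1]
        self.cache.put(mm_hash, len(payload), payload)
        return payload


if __name__ == "__main__":
    # Both sides must use the same capacity to stay mirrored.
    client, server = ToyClient(1 << 20), ToyServer(1 << 20)
    pixels = b"\x00" * 4096  # Stand-in for processed image tensors
    for _ in range(3):
        wire = client.get_and_update("img-123", pixels)  # None after 1st call
        assert server.get_and_update("img-123", wire) == pixels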