diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 4d5c961af98f..dcaf1069bfdf 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). +- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process) - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits @@ -129,20 +129,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory. Here are some examples: -??? code - - ```python - from vllm import LLM +```python +from vllm import LLM - # Available for Qwen2-VL series models - llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_kwargs={ - "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 - }) - - # Available for InternVL series models - llm = LLM(model="OpenGVLab/InternVL2-2B", - mm_processor_kwargs={ - "max_dynamic_patch": 4, # Default is 12 - }) - ``` +# Available for Qwen2-VL series models +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={ + "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 + }) + +# Available for InternVL series models +llm = LLM(model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={ + "max_dynamic_patch": 4, # Default is 12 + }) +``` diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 811925c19e63..bb7342c93fb9 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -2,6 +2,9 @@ This guide covers optimization strategies and performance tuning for vLLM V1. +!!! tip + Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory. + ## Preemption Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. @@ -126,62 +129,44 @@ Data parallelism replicates the entire model across multiple GPU sets and proces Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. -## Reducing Memory Usage - -If you encounter out-of-memory issues, consider these strategies: +## Input Processing -### Context Length and Batch Size +### Parallel Processing -You can reduce memory usage by limiting the context length and batch size: +You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). +This is useful when input processing (which is run inside the API server) +becomes a bottleneck compared to model execution (which is run inside engine core) +and you have excess CPU capacity. 
-```python -from vllm import LLM +```console +# Run 4 API processes and 1 engine core process +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - max_model_len=2048, # Limit context window - max_num_seqs=4 # Limit batch size -) +# Run 4 API processes and 2 engine core processes +vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 ``` -### Adjust CUDA Graph Compilation +!!! note + API server scale-out is only available for online inference. -CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level: - -```python -from vllm import LLM -from vllm.config import CompilationConfig, CompilationLevel - -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes - ) -) -``` +!!! note + [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled + because it requires a one-to-one correspondance between API and engine core processes. -Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`: +## Multi-Modal Caching -```python -from vllm import LLM +### Processor Cache -llm = LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True # Disable CUDA graph compilation -) -``` +By default, the multi-modal processor cache is enabled to avoid repeatedly processing +the same multi-modal inputs via Hugging Face `AutoProcessor`, +which commonly occurs in multi-turn conversations. -### Multimodal Models +You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable +(default 4 GiB per API process + 4 GiB per engine core process). 
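For example, a minimal sketch of overriding the per-process cache budget at serve time (the 2 GiB value and the model name here are purely illustrative):

```console
# Shrink the multi-modal processor cache to 2 GiB per process
VLLM_MM_INPUT_CACHE_GIB=2 vllm serve Qwen/Qwen2.5-VL-3B-Instruct
```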
-For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request: +If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`: ```python -from vllm import LLM - -# Accept up to 2 images per prompt -llm = LLM( - model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"image": 2} -) +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + disable_mm_preprocessor_cache=True) ``` diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index a38fc9216d40..59ec22a1e9fa 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -166,7 +166,7 @@ def parse_args(): parser.add_argument( "--disable-mm-preprocessor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", ) return parser.parse_args() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 16bb3712f551..5dbe00199428 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1565,7 +1565,7 @@ def parse_args(): parser.add_argument( "--disable-mm-preprocessor-cache", action="store_true", - help="If True, disables caching of multi-modal preprocessor/mapper.", + help="If True, disables caching of multi-modal processor.", ) parser.add_argument( diff --git a/tests/models/utils.py b/tests/models/utils.py index 3cd0721be1b6..0ba65532ab8c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from vllm.config import ModelConfig, RunnerOption +from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs @@ -256,7 +256,7 @@ def check_logprobs_close( def build_model_context( model_id: str, runner: RunnerOption = "auto", - dtype: Union[str, torch.dtype] = "auto", + dtype: ModelDType = "auto", model_config_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None, @@ -278,6 +278,7 @@ def build_model_context( model_info.check_transformers_version(on_fail="skip") model_config_kwargs = model_config_kwargs or {} + limit_mm_per_prompt = limit_mm_per_prompt or {} model_config = ModelConfig( model_id, runner=runner, diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py new file mode 100644 index 000000000000..e07b73bd257d --- /dev/null +++ b/tests/multimodal/test_cache.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, + MultiModalSharedField) + + +def _dummy_elem(modality: str, key: str, size: int): + return MultiModalFieldElem( + modality=modality, + key=key, + data=torch.empty((size, ), dtype=torch.int8), + field=MultiModalSharedField(1), + ) + + +def _dummy_item(modality: str, size_by_key: dict[str, int]): + return MultiModalKwargsItem.from_elems([ + _dummy_elem(modality, key, size) for key, size in size_by_key.items() + ]) + + +def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): + 
return MultiModalKwargs.from_items([ + _dummy_item(modality, size_by_key) + for modality, size_by_key in size_by_key_modality.items() + ]) + + +# yapf: disable +@pytest.mark.parametrize( + ("item", "expected_size"), + [ + (_dummy_item("a", {"a1": 100}), 100), + (_dummy_item("a", {"a1": 100, "a2": 110}), 210), + (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 + ], +) +# yapf: enable +def test_cache_item_size(item, expected_size): + cache = MultiModalCache.get_lru_cache(2048, type(item)) + + cache[""] = item + assert cache.currsize == expected_size + + cache[""] = MultiModalCacheItemMetadata.wraps(item) + assert cache.currsize == expected_size diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 508c773b8aed..cb489c47fd8f 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -6,20 +6,15 @@ import numpy as np import pytest -import torch from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalSharedField) # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (PlaceholderFeaturesInfo, - ProcessingCache, PromptIndexTargets, - PromptInsertion, PromptReplacement, - apply_text_matches, + PromptIndexTargets, PromptInsertion, + PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, find_text_matches, find_token_matches, @@ -902,45 +897,6 @@ def test_find_mm_placeholders( assert result == expected -def _dummy_elem(modality: str, key: str, size: int): - return MultiModalFieldElem( - modality=modality, - key=key, - data=torch.empty((size, ), dtype=torch.int8), - field=MultiModalSharedField(1), - ) - - -def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() - ]) - - -def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]): - return MultiModalKwargs.from_items([ - _dummy_item(modality, size_by_key) - for modality, size_by_key in size_by_key_modality.items() - ]) - - -# yapf: disable -@pytest.mark.parametrize( - ("item", "expected_size"), - [ - (_dummy_item("a", {"a1": 100}), 100), - (_dummy_item("a", {"a1": 100, "a2": 110}), 210), - (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - ], -) -# yapf: enable -def test_cache_item_size(item, expected_size): - cache = ProcessingCache.get_lru_cache(2048, type(item)) - cache[""] = item - - assert cache.currsize == expected_size - - @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), diff --git a/vllm/config.py b/vllm/config.py index 5c300e327397..038128d2b8c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -444,8 +444,7 @@ class ModelConfig: model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. 
""" disable_mm_preprocessor_cache: bool = False - """If `True`, disable caching of the multi-modal preprocessor/mapper (not - recommended).""" + """If `True`, disable caching of the multi-modal processor.""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -1686,6 +1685,31 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def processor_return_mm_hashes(self) -> bool: + """Whether the multi-modal processor should output hashes.""" + mm_config = self.multimodal_config + if mm_config is None: + return False + + return not mm_config.disable_mm_preprocessor_cache + + @property + def enable_mm_input_cache(self) -> bool: + """Whether the multi-modal input cache should be enabled.""" + mm_config = self.multimodal_config + if mm_config is None: + return False + + return not mm_config.disable_mm_preprocessor_cache + + def get_mm_input_cache_gb(self) -> int: + mm_config = self.multimodal_config + if mm_config is None: + return 0 + + return envs.VLLM_MM_INPUT_CACHE_GIB + @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -3363,7 +3387,7 @@ class MultiModalConfig: disable_mm_preprocessor_cache: bool = False """ - If `True`, disable caching of the processed multi-modal inputs. + If `True`, disable caching of the multi-modal processor. """ interleave_mm_strings: bool = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5eb9660cd1e8..f063fc78c2e8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1230,17 +1230,17 @@ def create_engine_config( enable_multimodal_encoder_data_parallel, ) - supports_mm_preprocessor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not supports_mm_preprocessor_cache - and model_config.is_multimodal_model - and not model_config.disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with data parallelism when there does not exist a " - "one-to-one correspondance between API process and " - "EngineCore process, so the cache will be disabled.") - model_config.set_disable_mm_preprocessor_cache(True) + if model_config.is_multimodal_model: + dp_supports_mm_processor_cache = (self.data_parallel_size == 1 + or data_parallel_external_lb) + if (not dp_supports_mm_processor_cache + and not model_config.disable_mm_preprocessor_cache): + logger.warning( + "Multi-modal processor cache is disabled because " + "it is not compatible with data parallelism when " + "there does not exist a one-to-one correspondance " + "between API and engine core processes.") + model_config.set_disable_mm_preprocessor_cache(True) speculative_config = self.create_speculative_config( target_model_config=model_config, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 9762a1de9edd..02b78f103c5a 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -163,9 +163,8 @@ def run_multi_api_server(args: argparse.Namespace): if model_config.is_multimodal_model and not ( orig_disable_mm_preprocessor_cache): - logger.warning( - "Multi-modal preprocessor cache is not compatible " - "with api_server_count > 1, so the cache will be disabled.") + logger.warning("Multi-modal processor cache is disabled because " + "it is not compatible with `api_server_count > 1`.") 
executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/envs.py b/vllm/envs.py index 78f955f78a98..18eca6364036 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,7 +64,7 @@ VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" - VLLM_MM_INPUT_CACHE_GIB: int = 8 + VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -552,8 +552,8 @@ def get_vllm_port() -> Optional[int]: "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - # Cache size (in GiB) for multimodal input cache - # Default is 4 GiB + # Cache size (in GiB per process) for multimodal input cache + # Default is 4 GiB per API process + 4 GiB per engine core process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py new file mode 100644 index 000000000000..262b22e554b9 --- /dev/null +++ b/vllm/multimodal/cache.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import sys +from collections.abc import Mapping +from dataclasses import dataclass +from typing import TypeVar, Union + +import torch + +from vllm.jsontree import json_map_leaves, json_reduce_leaves +from vllm.logger import init_logger +from vllm.utils import GiB_bytes, LRUCache + +from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors + +logger = init_logger(__name__) + + +@dataclass +class MultiModalCacheItemMetadata: + size: int + + @classmethod + def wraps(cls, value: "MultiModalCacheValue"): + return cls(size=MultiModalCache.get_item_size(value)) + + +MultiModalCacheValue = Union[ + MultiModalKwargs, + MultiModalKwargsItem, + Mapping[str, NestedTensors], + MultiModalCacheItemMetadata, +] + +_V = TypeVar("_V", bound=MultiModalCacheValue) + + +class MultiModalCache: + + @classmethod + def get_leaf_size( + cls, + leaf: object, + *, + debug: bool = False, + ) -> int: + # MultiModalKwargs is not a subclass of dict + if isinstance(leaf, MultiModalKwargs): + return cls.get_item_size(leaf.data, debug=debug) + + # MultiModalKwargsItem is not a subclass of dict + if isinstance(leaf, MultiModalKwargsItem): + leaf_data = {k: v.data for k, v in leaf.items()} + return cls.get_item_size(leaf_data, debug=debug) + + # sys.getsizeof doesn't work for tensors + if isinstance(leaf, torch.Tensor): + return leaf.nbytes + + if isinstance(leaf, MultiModalCacheItemMetadata): + return leaf.size + + return sys.getsizeof(leaf) + + @classmethod + def get_item_size( + cls, + value: MultiModalCacheValue, + *, + debug: bool = False, + ) -> int: + size = json_reduce_leaves( + lambda a, b: a + b, + json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug), + value), + ) + + if debug: + logger.debug("Calculated size of %s to be %.2f GiB", type(value), + size / GiB_bytes) + + return size + + @classmethod + def get_lru_cache( + cls, + capacity_gb: float, + value_type: type[_V], + *, + debug: bool = False, + ) -> LRUCache[str, _V]: + return LRUCache( + GiB_bytes * capacity_gb, + getsizeof=lambda x: cls.get_item_size(x, debug=debug), + ) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 46240855d12a..0378539495fd 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import sys from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, @@ -16,16 +15,16 @@ from typing_extensions import assert_never from vllm.inputs import InputProcessingContext -from vllm.jsontree import json_map_leaves, json_reduce_leaves from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) -from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby +from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby +from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, - MultiModalKwargsItem, NestedTensors, PlaceholderRange) + MultiModalKwargsItem, PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -888,9 +887,6 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) -_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") - - class ProcessingCacheOptionalItem(NamedTuple): key: str value: Optional[MultiModalKwargsItem] @@ -901,48 +897,7 @@ class ProcessingCacheItem(NamedTuple): value: MultiModalKwargsItem -class ProcessingCache: - - @staticmethod - def get_lru_cache( - capacity_gb: float, - value_type: type[_V], - *, - debug: bool = False, - ) -> LRUCache[str, _V]: - - def get_leaf_size(leaf: object) -> int: - # MultiModalKwargs is not a subclass of dict - if isinstance(leaf, MultiModalKwargs): - return get_item_size(leaf.data) - - # MultiModalKwargsItem is not a subclass of dict - if isinstance(leaf, MultiModalKwargsItem): - leaf_data = {k: v.data for k, v in leaf.items()} - return get_item_size(leaf_data) - - # sys.getsizeof doesn't work for tensors - if isinstance(leaf, torch.Tensor): - return leaf.nbytes - - return sys.getsizeof(leaf) - - def get_item_size( - value: Union[MultiModalKwargs, MultiModalKwargsItem, - Mapping[str, NestedTensors]] - ) -> int: - size = json_reduce_leaves( - lambda a, b: a + b, - json_map_leaves(get_leaf_size, value), - ) - - if debug: - logger.debug("Calculated size of %s to be %.2f GiB", - type(value), size / GiB_bytes) - - return size - - return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size) +class ProcessingCache(MultiModalCache): def __init__( self, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index eab1560b1a18..38b1d9b13fda 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, if mm_positions and len(mm_positions) != len(mm_hashes): raise ValueError( "The number of multi-modal positions and hashes must match. This " - "is likely because you do not enable MM preprocessor hashing. " - "Please set disable_mm_preprocessor_cache=False.") + "is likely because you did not enable MM hashing. " + "Please set `disable_mm_preprocessor_cache=False`.") # Note that we assume mm_positions is sorted by offset. 
# We do not need to check all mm inputs if the start token index is out of diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 79c47e102888..78b8fe4ea676 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -35,7 +35,7 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache +from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -124,8 +124,7 @@ def __init__(self, log_stats=self.log_stats, ) - # Setup MM Input Mapper. - self.mm_input_cache_server = MirroredProcessingCache( + self.mm_input_cache_server = MultiModalInputCacheServer( vllm_config.model_config) # Setup batch queue for pipeline parallelism. @@ -413,7 +412,7 @@ def preprocess_add_request( # Note on thread safety: no race condition. # `mm_input_cache_server` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. - request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( + request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index abe98a13dfd3..279c9f0007bc 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -1,54 +1,68 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional +from typing import TYPE_CHECKING, Optional -from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.multimodal import MultiModalKwargs -from vllm.multimodal.processing import ProcessingCache +from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata from vllm.utils import is_list_of -# The idea of multimodal preprocessing caching is based on having a client and +if TYPE_CHECKING: + from vllm.config import ModelConfig + +# The idea of multimodal input caching is based on having a client and # a server, where the client executes in the frontend process (=P0) and the # server in the core process (=P1). # -# -- Client: -# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs -# with built-in caching functionality, with mm_hash as its identifier. -# - MirroredProcessingCache to keep track of the cached entries and -# determine whether to send the MultiModalKwargs to P1. +# -- P0: +# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of +# each input multi-modal item (e.g. image), +# - BaseMultiModalProcessor processes the input items into `mm_inputs`, +# which are MultiModalKwargsItem instances that each correspond to an +# input multi-modal item. +# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding +# `mm_hash` for each item. It stores the `mm_hash` as keys and the size +# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking +# up additional memory in P0. +# - The `mm_hash` is always sent to P1. +# - The corresponding `mm_inputs` are only sent to P1 if they are not cached +# in MultiModalInputCacheServer. # -# -- Server: -# - MirroredProcessingCache to store the MultiModalKwargs from P0. +# -- P1: +# - If the `mm_hash` is cached (i.e. 
`mm_inputs` are not sent from P0), +# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`. +# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0), +# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`. +# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to +# the engine for model execution. # -# The caching for both client and server is mirrored, and this allows us -# to avoid the serialization of "mm_inputs" (like pixel values) between -# client (=P0) and server (=P1) processes if the mm_hash is found in the client -# cache. +# Both Client and Server must perform cache update and eviction based on the +# same item size. This ensures that the keys of MultiModalInputCacheClient +# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 +# whether a key is cached in MultiModalInputCacheServer by querying +# MultiModalInputCacheClient without having to communicate with P1. -# Both Client and Server must use the same cache size -# (to perform mirrored caching). This cache size is set by the environment -# variable VLLM_MM_INPUT_CACHE_GIB. +class MultiModalInputCacheClient: + """Used by P0 to check whether multi-modal kwargs are cached in P1.""" -class MirroredProcessingCache: + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() - def __init__(self, model_config): - mm_config = model_config.multimodal_config - disable_mm_preprocessor_cache = ( - mm_config is not None and mm_config.disable_mm_preprocessor_cache) - self.use_cache = not disable_mm_preprocessor_cache - self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, - MultiModalKwargs) + self.enabled = model_config.enable_mm_input_cache + self.mm_cache = MultiModalCache.get_lru_cache( + model_config.get_mm_input_cache_gb(), + MultiModalCacheItemMetadata, + ) - def get_and_update_p0( + def get_and_update( self, mm_inputs: Sequence[MultiModalKwargs], mm_hashes: list[str], ) -> Sequence[Optional[MultiModalKwargs]]: assert len(mm_inputs) == len(mm_hashes) - if not self.use_cache: + if not self.enabled: assert is_list_of(mm_inputs, MultiModalKwargs) return mm_inputs @@ -57,20 +71,37 @@ def get_and_update_p0( if self.mm_cache.get(mm_hash) is not None: mm_input = None else: - self.mm_cache[mm_hash] = mm_input + self.mm_cache[mm_hash] = \ + MultiModalCacheItemMetadata.wraps(mm_input) full_mm_inputs.append(mm_input) return full_mm_inputs - def get_and_update_p1( + def reset(self) -> None: + self.mm_cache.clear() + + +class MultiModalInputCacheServer: + """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + self.enabled = model_config.enable_mm_input_cache + self.mm_cache = MultiModalCache.get_lru_cache( + model_config.get_mm_input_cache_gb(), + MultiModalKwargs, + ) + + def get_and_update( self, mm_inputs: Sequence[Optional[MultiModalKwargs]], mm_hashes: list[str], ) -> Sequence[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) - if not self.use_cache: + if not self.enabled: assert is_list_of(mm_inputs, MultiModalKwargs) return mm_inputs @@ -85,7 +116,5 @@ def get_and_update_p1( return full_mm_inputs - def reset(self) -> bool: + def reset(self) -> None: self.mm_cache.clear() - - return True diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 692a7dd5640e..6e37ebeb8778 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,7 +19,7 @@ from vllm.sampling_params 
import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MirroredProcessingCache +from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_outlines import ( @@ -50,11 +50,8 @@ def __init__( self.tokenizer, mm_registry) - self.mm_input_cache_client = MirroredProcessingCache(self.model_config) - - # Multi-modal hasher (for images) - self.use_hash = self.mm_input_cache_client.use_cache or \ - self.cache_config.enable_prefix_caching + self.mm_input_cache_client = MultiModalInputCacheClient( + self.model_config) @property def mm_registry(self): @@ -256,11 +253,13 @@ def process_inputs( # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. + return_mm_hashes = (self.model_config.processor_return_mm_hashes + or bool(self.cache_config.enable_prefix_caching)) processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=self.use_hash, + return_mm_hashes=return_mm_hashes, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -312,7 +311,7 @@ def process_inputs( sorted_mm_hashes, ) = merge_and_sort_multimodal_metadata( decoder_inputs["mm_placeholders"], - decoder_inputs["mm_hashes"] if self.use_hash else None, + decoder_inputs["mm_hashes"] if return_mm_hashes else None, ) # The output of merged multi-modal processor (`decoder_mm_inputs`) @@ -339,7 +338,7 @@ def process_inputs( ] if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( + sorted_mm_inputs = self.mm_input_cache_client.get_and_update( orig_sorted_mm_inputs, sorted_mm_hashes) else: sorted_mm_inputs = orig_sorted_mm_inputs
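To make the client/server split documented in `vllm/v1/engine/mm_input_cache.py` concrete, here is a minimal, self-contained sketch of the mirrored caching protocol. Plain dicts stand in for the size-bounded LRU caches of the real `MultiModalInputCacheClient`/`MultiModalInputCacheServer`, eviction is omitted, and the class and variable names are simplified for illustration only:

```python
# Simplified sketch of the mirrored P0/P1 multi-modal input caching protocol.

class CacheClient:  # runs in the API process (P0)
    def __init__(self) -> None:
        self.seen: dict[str, int] = {}  # mm_hash -> size metadata only

    def get_and_update(self, mm_inputs: list, mm_hashes: list[str]) -> list:
        out = []
        for inp, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_hash in self.seen:
                out.append(None)               # P1 already has it; skip the transfer
            else:
                self.seen[mm_hash] = len(inp)  # remember the key (and size), not the data
                out.append(inp)
        return out


class CacheServer:  # runs in the engine core process (P1)
    def __init__(self) -> None:
        self.cache: dict[str, object] = {}  # mm_hash -> full processed inputs

    def get_and_update(self, mm_inputs: list, mm_hashes: list[str]) -> list:
        out = []
        for inp, mm_hash in zip(mm_inputs, mm_hashes):
            if inp is None:
                out.append(self.cache[mm_hash])  # reuse the cached entry
            else:
                self.cache[mm_hash] = inp        # first sighting: store it
                out.append(inp)
        return out


client, server = CacheClient(), CacheServer()

# First request: the processed inputs travel from P0 to P1 and are cached on both sides.
first = client.get_and_update(["<pixel values>"], ["img-123"])
assert server.get_and_update(first, ["img-123"]) == ["<pixel values>"]

# Repeated request with the same image: P0 sends only a placeholder (None),
# and P1 recovers the inputs from its own cache.
second = client.get_and_update(["<pixel values>"], ["img-123"])
assert second == [None]
assert server.get_and_update(second, ["img-123"]) == ["<pixel values>"]
```

Because both sides insert and (in the real implementation) evict entries based on the same keys and item sizes, P0 can tell whether P1 still holds an entry without an extra round trip.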