diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index dcaf1069bfdf..058eba5fe0b1 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of the multi-modal processor cache via the `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index bb7342c93fb9..2eeb8ad25de5 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -161,12 +161,18 @@ By default, the multi-modal processor cache is enabled to avoid repeatedly proce the same multi-modal inputs via Hugging Face `AutoProcessor`, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB per API process + 4 GiB per engine core process). +If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. -If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`: +Examples: ```python +# Use a larger cache llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", -          disable_mm_preprocessor_cache=True) +          mm_processor_cache_gb=8) + +# Disable the cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", +          mm_processor_cache_gb=0) ``` diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index 59ec22a1e9fa..1f6e5ba1467c 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace): max_model_len=4096, max_num_seqs=2, tensor_parallel_size=2, -        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, +        mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4, ) prompt = "Describe this image in one sentence." @@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace): limit_mm_per_prompt={"image": max_img_per_msg}, max_model_len=max_img_per_msg * max_tokens_per_img, tensor_parallel_size=2, -        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, +        mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4, ) prompt = "Describe the following image."
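For context on the example-script changes above, the following is a condensed sketch (not part of the diff) of how the renamed `--disable-mm-processor-cache` flag is wired into the new `mm_processor_cache_gb` engine argument; the model name is only a placeholder taken from the docs example.

```python
import argparse

from vllm import LLM

# Condensed sketch of the pattern the updated example scripts follow: the
# boolean CLI switch is translated into the integer cache-size argument.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-mm-processor-cache",
    action="store_true",
    help="If set, disables caching of the multi-modal processor.",
)
args = parser.parse_args()

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",  # placeholder model
    # 0 GiB disables the processor cache; 4 GiB matches the engine default.
    mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
)
```

Because the setting is now an integer rather than a boolean, scripts can also raise the cache size (for example `mm_processor_cache_gb=8`), which the old flag could not express.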
@@ -164,7 +164,7 @@ def parse_args(): ) parser.add_argument( - "--disable-mm-preprocessor-cache", + "--disable-mm-processor-cache", action="store_true", help="If True, disables caching of multi-modal processor.", ) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5dbe00199428..1314d33e9009 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1563,7 +1563,7 @@ def parse_args(): ) parser.add_argument( - "--disable-mm-preprocessor-cache", + "--disable-mm-processor-cache", action="store_true", help="If True, disables caching of multi-modal processor.", ) @@ -1603,7 +1603,7 @@ def main(args): engine_args = asdict(req_data.engine_args) | { "seed": args.seed, - "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, + "mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4, } llm = LLM(**engine_args) diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index f65385150d75..a5d6948f06ef 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -62,9 +62,7 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - vllm_runner_kwargs_: dict[str, Any] = { - "disable_mm_preprocessor_cache": True, - } + vllm_runner_kwargs_: dict[str, Any] = {"mm_processor_cache_gb": 0} if model_info.tokenizer: vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer if model_info.tokenizer_mode: diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 9ef7af556291..5e14f0f9964d 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -15,14 +15,14 @@ ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) @pytest.mark.parametrize("mm_processor_kwargs", [{}]) @pytest.mark.parametrize("num_imgs", [1, 5]) -@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) +@pytest.mark.parametrize("mm_processor_cache_gb", [0, 4]) @pytest.mark.parametrize("tokenized_prompt", [True, False]) def test_processor_override( image_assets: ImageTestAssets, model_id: str, mm_processor_kwargs: dict, num_imgs: int, - disable_mm_preprocessor_cache: bool, + mm_processor_cache_gb: int, tokenized_prompt: bool, ): """Ensure llama4 processor works properly.""" @@ -30,7 +30,7 @@ def test_processor_override( model_id, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": num_imgs}, - disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, + mm_processor_cache_gb=mm_processor_cache_gb, ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) config = processor.info.get_hf_config() diff --git a/tests/models/utils.py b/tests/models/utils.py index 27ce9de46934..1e3d51aeec64 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -261,7 +261,7 @@ def build_model_context( model_config_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None, - disable_mm_preprocessor_cache: bool = True, + mm_processor_cache_gb: int = 0, ): """Creates an InputContext for a given model. 
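The test changes above replace the boolean `disable_mm_preprocessor_cache` parameter with the integer cache size. As a hypothetical sketch (not in the diff), the same parametrization style can be checked at the `EngineArgs` level without loading a model:

```python
import pytest

from vllm.engine.arg_utils import EngineArgs


# Hypothetical test sketch: like test_llama4.py above, parametrize the integer
# cache size (0 = disabled, 4 = default GiB) instead of the old boolean flag.
@pytest.mark.parametrize("mm_processor_cache_gb", [0, 4])
def test_mm_processor_cache_gb_roundtrip(mm_processor_cache_gb: int):
    engine_args = EngineArgs(
        model="Qwen/Qwen2.5-VL-3B-Instruct",  # placeholder model
        mm_processor_cache_gb=mm_processor_cache_gb,
    )
    assert engine_args.mm_processor_cache_gb == mm_processor_cache_gb
```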
@@ -291,7 +291,7 @@ def build_model_context( seed=0, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt=limit_mm_per_prompt, -        disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, +        mm_processor_cache_gb=mm_processor_cache_gb, hf_overrides=model_info.hf_overrides, **model_config_kwargs, ) diff --git a/vllm/config.py b/vllm/config.py index 44a8d871f0db..8dcd429a6b33 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -443,8 +443,15 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ -    disable_mm_preprocessor_cache: bool = False -    """If `True`, disable caching of the multi-modal processor.""" +    mm_processor_cache_gb: int = 4 +    """The size (in GiB) of the multi-modal processor cache, which is used to +    avoid re-processing past multi-modal inputs. + +    This cache is duplicated for each API process and engine core process, +    resulting in a total memory usage of +    `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + +    Set to `0` to disable this cache completely (not recommended).""" override_neuron_config: dict[str, Any] = field(default_factory=dict) """Initialize non-default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to @@ -881,17 +888,16 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, -            disable_mm_preprocessor_cache=self. -            disable_mm_preprocessor_cache, +            mm_processor_cache_gb=self.mm_processor_cache_gb, interleave_mm_strings=self.interleave_mm_strings) return None -    def set_disable_mm_preprocessor_cache(self, value: bool) -> None: +    def set_mm_processor_cache_gb(self, value: int) -> None: mm_config = self.get_multimodal_config() -        self.disable_mm_preprocessor_cache = value -        mm_config.disable_mm_preprocessor_cache = value +        self.mm_processor_cache_gb = value +        mm_config.mm_processor_cache_gb = value def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( @@ -1698,7 +1704,16 @@ def processor_return_mm_hashes(self) -> bool: if mm_config is None: return False -        return not mm_config.disable_mm_preprocessor_cache +        return mm_config.mm_processor_cache_gb > 0 + +    @property +    def enable_mm_processor_cache(self) -> bool: +        """Whether the multi-modal processor cache should be enabled.""" +        mm_config = self.multimodal_config +        if mm_config is None: +            return False + +        return mm_config.mm_processor_cache_gb > 0 @property def enable_mm_input_cache(self) -> bool: @@ -1707,7 +1722,7 @@ def enable_mm_input_cache(self) -> bool: if mm_config is None: return False -        return not mm_config.disable_mm_preprocessor_cache +        return mm_config.mm_processor_cache_gb > 0 def get_mm_input_cache_gb(self) -> int: mm_config = self.multimodal_config @@ -3391,9 +3406,16 @@ class MultiModalConfig: `{"num_crops": 4}`. """ -    disable_mm_preprocessor_cache: bool = False +    mm_processor_cache_gb: int = 4 """ -    If `True`, disable caching of the multi-modal processor. +    The size (in GiB) of the multi-modal processor cache, which is used to +    avoid re-processing past multi-modal inputs. + +    This cache is duplicated for each API process and engine core process, +    resulting in a total memory usage of +    `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + +    Set to `0` to disable this cache completely (not recommended).
""" interleave_mm_strings: bool = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a18cd9dde391..d2153dfae341 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -358,8 +358,8 @@ class EngineArgs: "media_io_kwargs") mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs -    disable_mm_preprocessor_cache: bool = \ -        MultiModalConfig.disable_mm_preprocessor_cache +    disable_mm_preprocessor_cache: bool = False  # DEPRECATED +    mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb # LoRA fields enable_lora: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled @@ -720,8 +720,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"]) multimodal_group.add_argument( -            "--disable-mm-preprocessor-cache", -            **multimodal_kwargs["disable_mm_preprocessor_cache"]) +            "--mm-processor-cache-gb", +            **multimodal_kwargs["mm_processor_cache_gb"]) +        multimodal_group.add_argument("--disable-mm-preprocessor-cache", +                                      type=bool, +                                      deprecated=True) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]) @@ -886,6 +889,23 @@ def create_model_config(self) -> ModelConfig: self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" self.load_format = "runai_streamer" +        if self.disable_mm_preprocessor_cache: +            logger.warning( +                "`--disable-mm-preprocessor-cache` is deprecated " +                "and will be removed in v0.13. " +                "Please use `--mm-processor-cache-gb 0` instead.", ) + +            self.mm_processor_cache_gb = 0 +        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: +            logger.warning( +                "`VLLM_MM_INPUT_CACHE_GIB` is deprecated " +                "and will be removed in v0.13. " +                "Please use `--mm-processor-cache-gb %d` instead.", +                envs.VLLM_MM_INPUT_CACHE_GIB, +            ) + +            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -922,7 +942,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, -            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, +            mm_processor_cache_gb=self.mm_processor_cache_gb, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1234,13 +1254,13 @@ def create_engine_config( dp_supports_mm_processor_cache = (self.data_parallel_size == 1 or data_parallel_external_lb) if (not dp_supports_mm_processor_cache -                and not model_config.disable_mm_preprocessor_cache): +                and model_config.mm_processor_cache_gb > 0): logger.warning( "Multi-modal processor cache is disabled because " "it is not compatible with data parallelism when " "there does not exist a one-to-one correspondance " "between API and engine core processes.") -            model_config.set_disable_mm_preprocessor_cache(True) +            model_config.set_mm_processor_cache_gb(0) speculative_config = self.create_speculative_config( target_model_config=model_config, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 02b78f103c5a..803a3e004656 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers = args.api_server_count assert num_api_servers > 0 -    orig_disable_mm_preprocessor_cache =
args.disable_mm_preprocessor_cache + orig_mm_processor_cache_gb = args.mm_processor_cache_gb if num_api_servers > 1: setup_multiprocess_prometheus() # Not compatible with API server scale-out - args.disable_mm_preprocessor_cache = True + args.mm_processor_cache_gb = 0 listen_address, sock = setup_server(args) @@ -161,8 +161,7 @@ def run_multi_api_server(args: argparse.Namespace): raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " "with api_server_count > 1") - if model_config.is_multimodal_model and not ( - orig_disable_mm_preprocessor_cache): + if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0: logger.warning("Multi-modal processor cache is disabled because " "it is not compatible with `api_server_count > 1`.") diff --git a/vllm/envs.py b/vllm/envs.py index 212eaf015a83..8b12a7ee2b98 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -561,7 +561,7 @@ def get_vllm_port() -> Optional[int]: "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - # Cache size (in GiB per process) for multimodal input cache + # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache # Default is 4 GiB per API process + 4 GiB per engine core process "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5f5b620e0cf7..dca04e9a1e22 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -6,7 +6,6 @@ import torch.nn as nn -from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, @@ -96,11 +95,22 @@ def __init__(self) -> None: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) + self._processor_cache: Optional[ProcessingCache] = None + + def _get_processor_cache(self, model_config: "ModelConfig"): + capacity_gb = model_config.mm_processor_cache_gb + if capacity_gb is None: + return None # Overrides `disable_cache` argument + + if self._processor_cache is None: + self._processor_cache = ProcessingCache(capacity_gb) + + return self._processor_cache def reset_processor_cache(self) -> bool: """Reset the multi-modal processing cache.""" - self._processing_cache.reset() + if self._processor_cache: + self._processor_cache.reset() return True # Success @@ -244,14 +254,14 @@ def create_processor( if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: - mm_config = model_config.get_multimodal_config() - disable_cache = mm_config.disable_mm_preprocessor_cache + disable_cache = not model_config.enable_mm_processor_cache model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - cache = None if disable_cache else self._processing_cache + cache = None if disable_cache else self._get_processor_cache( + model_config) return factories.build_processor(ctx, cache=cache) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 38b1d9b13fda..626aa35a770c 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -430,7 +430,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, raise ValueError( "The number of multi-modal positions and hashes must match. 
This " "is likely because you did not enable MM hashing. " - "Please set `disable_mm_preprocessor_cache=False`.") + "Please set `mm_processor_cache_gb > 0`.") # Note that we assume mm_positions is sorted by offset. # We do not need to check all mm inputs if the start token index is out of
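To make the memory accounting in the new `mm_processor_cache_gb` docstring concrete, here is a small back-of-the-envelope sketch (not part of the diff; the helper name is made up for illustration):

```python
# Worst-case CPU RAM used by the multi-modal processor caches, per the
# docstring added to ModelConfig/MultiModalConfig above: the cache is
# duplicated once per API process and once per engine core process.
def total_mm_processor_cache_gib(mm_processor_cache_gb: int = 4,
                                 api_server_count: int = 1,
                                 data_parallel_size: int = 1) -> int:
    return mm_processor_cache_gb * (api_server_count + data_parallel_size)


# Default single-process deployment: 4 * (1 + 1) = 8 GiB.
print(total_mm_processor_cache_gib())
# Larger cache with the same topology: 8 * (1 + 1) = 16 GiB.
print(total_mm_processor_cache_gib(mm_processor_cache_gb=8))
```

Note that, as shown in the `serve.py` and `arg_utils.py` hunks above, the value is forced to 0 (cache disabled) when `api_server_count > 1` or when data parallelism lacks a one-to-one mapping between API and engine core processes, so this budget only applies to deployments where the cache can actually be kept.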