Commit 7af83cd

ywang96 authored and lulmer committed
[V1] Consolidate MM cache size to vllm.envs (vllm-project#13239)
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent f7447b2 commit 7af83cd

File tree

vllm/envs.py
vllm/multimodal/registry.py
vllm/v1/engine/mm_input_cache.py

3 files changed: +18 -11 lines changed
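
This change replaces the hard-coded MM_CACHE_SIZE = 256 constant with a single environment variable, VLLM_MM_INPUT_CACHE_SIZE. As a hedged usage sketch (not part of the diff), the default of 256 multimodal items could be overridden before vllm.envs is consulted:

import os

# Hypothetical override: must be set before vllm.envs reads the environment
# (the default below the diff is "256" multimodal data items).
os.environ["VLLM_MM_INPUT_CACHE_SIZE"] = "512"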

vllm/envs.py

Lines changed: 9 additions & 2 deletions

@@ -55,6 +55,7 @@
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
     VLLM_VIDEO_FETCH_TIMEOUT: int = 30
     VLLM_AUDIO_FETCH_TIMEOUT: int = 10
+    VLLM_MM_INPUT_CACHE_SIZE: int = 256
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None

@@ -401,15 +402,21 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),

     # Timeout for fetching videos when serving multimodal models
-    # Default is 15 seconds
+    # Default is 30 seconds
     "VLLM_VIDEO_FETCH_TIMEOUT":
-    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")),
+    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")),

     # Timeout for fetching audio when serving multimodal models
     # Default is 10 seconds
     "VLLM_AUDIO_FETCH_TIMEOUT":
     lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),

+    # Cache size for multimodal feature/input cache for multimodal models
+    # in unit of number of multimodal data items (e.g. image, video, audio).
+    # Default is 256 multimodal data items.
+    "VLLM_MM_INPUT_CACHE_SIZE":
+    lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
+
     # Path to the XLA persistent cache directory.
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":

vllm/multimodal/registry.py

Lines changed: 2 additions & 4 deletions

@@ -8,6 +8,7 @@

 import torch.nn as nn

+from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
 from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -28,9 +29,6 @@

 logger = init_logger(__name__)

-# TODO: Tune the MM cache size
-MM_CACHE_SIZE = 256
-
 N = TypeVar("N", bound=Type[nn.Module])
 _I = TypeVar("_I", bound=BaseProcessingInfo)
 _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

@@ -121,7 +119,7 @@ def __init__(

         self._limits_by_model = _MultiModalLimits()

-        self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
+        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE)

     def register_plugin(self, plugin: MultiModalPlugin) -> None:
         """

vllm/v1/engine/mm_input_cache.py

Lines changed: 7 additions & 5 deletions

@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional

 from vllm.config import ModelConfig
+from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)

@@ -28,9 +29,8 @@
 # client (=P0) and server (=P1) processes.

 # Both Client and Server must use the same cache size
-# (to perform mirrored caching)
-# TODO: Tune the MM cache size
-MM_CACHE_SIZE = 256
+# (to perform mirrored caching). This cache size is set by the environment
+# variable VLLM_MM_INPUT_CACHE_SIZE.


 # TODO(ywang96): Deprecate this class once all multimodal models migrate to use

@@ -50,7 +50,8 @@ def __init__(

         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str,
+                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)

         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None

@@ -127,7 +128,8 @@ class MMInputCacheServer:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str,
+                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)

     def get_and_update(
         self,
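
The comment in this file explains why the client (P0) and server (P1) caches must share one size: the caches are mirrored, so a client-side hit implies the server still holds the corresponding entry. Below is a hypothetical sketch of that idea; the names and helpers are illustrative and not the vLLM API.

import os
from collections import OrderedDict
from typing import Any, Optional, Tuple

# Both sides read the same capacity, mirroring VLLM_MM_INPUT_CACHE_SIZE.
CACHE_SIZE = int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256"))

client_cache: "OrderedDict[str, Any]" = OrderedDict()
server_cache: "OrderedDict[str, Any]" = OrderedDict()


def _put(cache: "OrderedDict[str, Any]", key: str, value: Any) -> None:
    # Insert as most-recently-used and evict beyond the shared capacity.
    cache[key] = value
    cache.move_to_end(key)
    while len(cache) > CACHE_SIZE:
        cache.popitem(last=False)


def client_send(key: str, mm_kwargs: Any) -> Tuple[str, Optional[Any]]:
    # P0 side: on a hit, elide the payload; the server's mirror still has it.
    if key in client_cache:
        client_cache.move_to_end(key)
        return key, None
    _put(client_cache, key, mm_kwargs)
    return key, mm_kwargs


def server_receive(key: str, mm_kwargs: Optional[Any]) -> Any:
    # P1 side: a missing payload means "reuse the mirrored entry".
    if mm_kwargs is None:
        server_cache.move_to_end(key)
        return server_cache[key]
    _put(server_cache, key, mm_kwargs)
    return mm_kwargs

Because both sides apply the same insertions, recency updates, and capacity, their eviction decisions stay in lockstep, which is what makes it safe for the client to elide the payload on a hit.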
