# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union

import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.sched.output import NewRequestData
    from vllm.v1.request import Request

logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"

# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()


def is_false(value: str) -> bool:
    """Check whether the given string value is equivalent to 'false'.
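
    Illustrative doctest (inputs chosen arbitrarily):

    >>> is_false("No")
    True
    >>> is_false("1")
    False
    """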
    return value.lower() in ("false", "0", "no", "n", "off")


def lmcache_get_or_create_config() -> Config | V1Config:
    """Get the LMCache configuration from the environment variable
    `LMCACHE_CONFIG_FILE`. If the environment variable is not set, the
    configuration is built from environment variables (via `from_env()`),
    falling back to the defaults.

    This function is thread-safe and implements the singleton pattern,
    ensuring the configuration is loaded only once.
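
    Example (illustrative; `chunk_size` is read elsewhere in this module, so
    it is assumed to be a valid attribute of the returned config):

        config = lmcache_get_or_create_config()
        logger.info("LMCache chunk size: %d", config.chunk_size)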
    """
    global _config_instance

    # Double-checked locking for thread-safe singleton
    if _config_instance is None:
        with _config_lock:
            if _config_instance is None:  # Check again within lock
                if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
                    logger.warning(
                        "Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
                        "Using legacy configuration is deprecated and will "
                        "be removed soon! Please set LMCACHE_USE_EXPERIMENTAL "
                        "to True."
                    )
                    LMCacheEngineConfig = Config  # type: ignore[assignment]
                else:
                    LMCacheEngineConfig = V1Config  # type: ignore[assignment]

                if "LMCACHE_CONFIG_FILE" not in os.environ:
                    logger.warning(
                        "No LMCache configuration file is set. Trying to read"
                        " configurations from the environment variables."
                    )
                    logger.warning(
                        "You can set the configuration file through "
                        "the environment variable: LMCACHE_CONFIG_FILE"
                    )
                    _config_instance = LMCacheEngineConfig.from_env()
                else:
                    config_file = os.environ["LMCACHE_CONFIG_FILE"]
                    logger.info("Loading LMCache config file %s", config_file)
                    _config_instance = LMCacheEngineConfig.from_file(config_file)
            # Update config from environment variables
            _config_instance.update_config_from_env()
    return _config_instance

def hex_hash_to_int16(s: str) -> int:
    """
    Convert a hex hash string to a 16-bit integer.
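
    Illustrative doctest (hash value chosen arbitrarily):

    >>> hex_hash_to_int16("deadbeef")  # 0xDEADBEEF & 0xFFFF == 0xBEEF
    48879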
    """
    return int(s, 16) & 0xFFFF


def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """
    Overwrite token_ids in-place for multimodal placeholders using
    efficient slice assignments.
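
    Illustrative sketch (values chosen arbitrarily; `placeholder` stands for
    a `PlaceholderRange` with `offset=2` and `length=3`):

        ids = torch.zeros(8, dtype=torch.long)
        apply_mm_hashes_to_token_ids(ids, ["beef"], [placeholder])
        # ids[2:5] now hold hex_hash_to_int16("beef") == 0xBEEF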
    """
    n = token_ids.size(0)
    for hash_str, placeholder in zip(mm_hashes, mm_positions):
        start, length = placeholder.offset, placeholder.length
        if start >= n:
            continue
        end = min(start + length, n)
        token_ids[start:end] = hex_hash_to_int16(hash_str)
    return token_ids


def mla_enabled(model_config: "ModelConfig") -> bool:
    """Return True if the model config explicitly enables MLA
    (multi-head latent attention)."""
    return (
        hasattr(model_config, "use_mla")
        and isinstance(model_config.use_mla, bool)
        and model_config.use_mla
    )


def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Create LMCacheEngineMetadata from vLLM configuration.

    This function extracts common metadata-creation logic that was duplicated
    across multiple files.

    Args:
        vllm_config (VllmConfig): vLLM configuration object containing the
            model, parallel, and cache configs (alternative to the individual
            config parameters below).
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config).
        parallel_config (ParallelConfig): Parallel configuration (alternative
            to vllm_config).
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config).
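
    Returns:
        tuple: The constructed `LMCacheEngineMetadata` and the singleton
        LMCache config, as `(metadata, config)`.

    Example (illustrative sketch; assumes a fully populated `vllm_config`):

        metadata, config = create_lmcache_metadata(vllm_config=vllm_config)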
    """
    # Third Party
    from lmcache.config import LMCacheEngineMetadata

    # First Party
    from vllm.utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()
    # Support both vllm_config object and individual config parameters
    if vllm_config is not None:
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        cache_cfg = vllm_config.cache_config
    else:
        if model_config is None or parallel_config is None or cache_config is None:
            raise ValueError(
                "Either vllm_config must be provided, or all of "
                "model_config, parallel_config, and cache_config must be provided."
            )
        model_cfg = model_config
        parallel_cfg = parallel_config
        cache_cfg = cache_config

    # Get KV cache dtype
    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)

    # Check if MLA is enabled
    use_mla = mla_enabled(model_cfg)

    # Construct KV shape (for memory pool)
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # Create metadata
    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    return metadata, config


def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Normalize multimodal information from a Request into parallel lists.

    This helper reads either:
    1) `request.mm_features` (objects each exposing `.identifier` and
       `.mm_position`), or
    2) the legacy fields `request.mm_hashes` and `request.mm_positions`.

    It returns two equally sized lists: the multimodal hash identifiers and
    their corresponding positions. If the request contains no multimodal
    info, it returns `([], [])`.

    Args:
        request (Request | NewRequestData): The source object.
        modify (bool):
            Controls copy semantics for the legacy-path return values.
            - If True and the legacy fields are used, shallow copies are
              returned so the caller can mutate the lists without affecting
              `request`.
            - If False, the original legacy sequences are returned as-is
              (zero-copy); treat them as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
        May be `([], [])` when no multimodal data is present.
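
    Example (illustrative; `req` stands for any object exposing the fields
    described above):

        mm_hashes, mm_positions = extract_mm_features(req, modify=True)
        for hash_str, pos in zip(mm_hashes, mm_positions):
            logger.debug("mm item %s at offset %d", hash_str, pos.offset)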
    """
    if getattr(request, "mm_features", None):
        mm_hashes, mm_positions = zip(
            *((f.identifier, f.mm_position) for f in request.mm_features)
        )
        return (list(mm_hashes), list(mm_positions))
    elif getattr(request, "mm_hashes", None):
        if modify:
            return (
                request.mm_hashes.copy(),  # type: ignore
                request.mm_positions.copy(),  # type: ignore
            )
        else:
            return (request.mm_hashes, request.mm_positions)  # type: ignore
    else:
        return ([], [])