
Commit 83f478b

[KVConnector] Migrate the LMCache integration code to be vLLM native (#25542)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
1 parent 269c4db commit 83f478b

File tree

4 files changed: +1637 −2 lines changed

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -3,14 +3,19 @@
 from typing import TYPE_CHECKING, Any
 
 import torch
-from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
+from lmcache.integration.vllm.vllm_v1_adapter import (
+    LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
+)
 
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
 )
+from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
+    vllm_v1_adapter as _adapter,
+)
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -26,7 +31,18 @@
 class LMCacheConnectorV1(KVConnectorBase_V1):
     def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
-        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
+        assert vllm_config.kv_transfer_config is not None
+        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
+            "use_native", False
+        )
+        if use_native:
+            logger.info("Initializing native LMCache connector")
+            cls = _adapter.LMCacheConnectorV1Impl
+        else:
+            logger.info("Initializing latest dev LMCache connector")
+            cls = LMCacheConnectorLatestImpl
+
+        self._lmcache_engine = cls(vllm_config, role, self)
 
     # ==============================
     # Worker-side methods
```
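
The connector now picks its implementation at construction time: the in-tree adapter under `vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration` when `use_native` is set in the connector's extra config, and the adapter shipped with the installed `lmcache` package otherwise. A minimal sketch of opting into the native path (the model name is an arbitrary example, not part of this commit):

```python
# Hedged sketch: enable the vLLM-native LMCache integration.
# get_from_extra_config("use_native", False) in the diff above reads the
# kv_connector_extra_config dict; the model name below is illustrative only.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model
    kv_transfer_config=KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
        # Omitting this key keeps the external (latest dev) LMCache adapter.
        kv_connector_extra_config={"use_native": True},
    ),
)
```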
Lines changed: 2 additions & 0 deletions

```diff
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
```
Lines changed: 221 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union

import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.sched.output import NewRequestData
    from vllm.v1.request import Request

logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"

# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()


def is_false(value: str) -> bool:
    """Check if the given string value is equivalent to 'false'."""
    return value.lower() in ("false", "0", "no", "n", "off")


def lmcache_get_or_create_config() -> Config | V1Config:
    """Get the LMCache configuration from the environment variable
    `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
    function will return the default configuration.

    This function is thread-safe and implements the singleton pattern,
    ensuring the configuration is loaded only once.
    """
    global _config_instance

    # Double-checked locking for a thread-safe singleton
    if _config_instance is None:
        with _config_lock:
            if _config_instance is None:  # Check again within the lock
                if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
                    logger.warning(
                        "Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
                        "Using legacy configuration is deprecated and will "
                        "be removed soon! Please set LMCACHE_USE_EXPERIMENTAL "
                        "to True."
                    )
                    LMCacheEngineConfig = Config  # type: ignore[assignment]
                else:
                    LMCacheEngineConfig = V1Config  # type: ignore[assignment]

                if "LMCACHE_CONFIG_FILE" not in os.environ:
                    logger.warning(
                        "No LMCache configuration file is set. Trying to read"
                        " configurations from the environment variables."
                    )
                    logger.warning(
                        "You can set the configuration file through "
                        "the environment variable: LMCACHE_CONFIG_FILE"
                    )
                    _config_instance = LMCacheEngineConfig.from_env()
                else:
                    config_file = os.environ["LMCACHE_CONFIG_FILE"]
                    logger.info("Loading LMCache config file %s", config_file)
                    _config_instance = LMCacheEngineConfig.from_file(config_file)
                    # Update config from environment variables
                    _config_instance.update_config_from_env()
    return _config_instance
```
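
Because `lmcache_get_or_create_config()` caches its result behind a lock, every later caller gets the same object no matter which thread won the race. A hedged usage sketch, assuming the `lmcache` package is installed and the helper above is in scope (the config path is a made-up example):

```python
import os

# Point LMCache at a config file *before* the first call; later calls
# ignore the environment because the singleton is already cached.
os.environ["LMCACHE_CONFIG_FILE"] = "/etc/lmcache/config.yaml"  # example path

config = lmcache_get_or_create_config()  # loads the file, then env overrides
assert lmcache_get_or_create_config() is config  # same cached instance
```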
```python
def hex_hash_to_int16(s: str) -> int:
    """
    Convert a hex hash string to a 16-bit integer.
    """
    return int(s, 16) & 0xFFFF


def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """
    Overwrite token_ids in-place for multimodal placeholders using
    efficient slice assignments.
    """
    n = token_ids.size(0)
    for hash_str, placeholder in zip(mm_hashes, mm_positions):
        start, length = placeholder.offset, placeholder.length
        if start >= n:
            continue
        end = min(start + length, n)
        token_ids[start:end] = hex_hash_to_int16(hash_str)
    return token_ids
```
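
`apply_mm_hashes_to_token_ids` stamps a 16-bit fingerprint of each multimodal hash over the placeholder span, so the token sequence reflects the multimodal content when it is later hashed for caching. A runnable sketch with a minimal stand-in for `PlaceholderRange` (the stub and values are illustrative, not from this commit; the helper above is assumed in scope):

```python
from dataclasses import dataclass

import torch


@dataclass
class StubPlaceholderRange:
    """Minimal stand-in for vllm.multimodal.inputs.PlaceholderRange."""
    offset: int
    length: int


token_ids = torch.arange(10)
mm_hashes = ["deadbeef"]  # example hex hash
mm_positions = [StubPlaceholderRange(offset=4, length=3)]

apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions)
# 0xdeadbeef & 0xFFFF == 0xBEEF == 48879, written over positions 4..6:
# tensor([0, 1, 2, 3, 48879, 48879, 48879, 7, 8, 9])
print(token_ids)
```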
```python
def mla_enabled(model_config: "ModelConfig") -> bool:
    return (
        hasattr(model_config, "use_mla")
        and isinstance(model_config.use_mla, bool)
        and model_config.use_mla
    )


def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Create LMCacheEngineMetadata from vLLM configuration.

    This function extracts common metadata creation logic that was duplicated
    across multiple files.

    Args:
        vllm_config (VllmConfig): vLLM configuration object containing model,
            parallel, and cache configs (alternative to individual config
            parameters)
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config)
        parallel_config (ParallelConfig): Parallel configuration (alternative
            to vllm_config)
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config)
    """
    # Third Party
    # First Party
    from lmcache.config import LMCacheEngineMetadata

    from vllm.utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()
    # Support both vllm_config object and individual config parameters
    if vllm_config is not None:
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        cache_cfg = vllm_config.cache_config
    else:
        if model_config is None or parallel_config is None or cache_config is None:
            raise ValueError(
                "Either vllm_config must be provided, or all of "
                "model_config, parallel_config, and cache_config must be provided."
            )
        model_cfg = model_config
        parallel_cfg = parallel_config
        cache_cfg = cache_config

    # Get KV cache dtype
    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)

    # Check if MLA is enabled
    use_mla = mla_enabled(model_cfg)

    # Construct KV shape (for memory pool)
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # Create metadata
    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    return metadata, config
```
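
Both calling conventions resolve to the same metadata. A short sketch, assuming an engine's `VllmConfig` is at hand as `vllm_config` and the helper above is in scope:

```python
# Pass the whole VllmConfig...
metadata, config = create_lmcache_metadata(vllm_config=vllm_config)

# ...or pass the three sub-configs individually; omitting any of the
# three raises the ValueError shown above.
metadata, config = create_lmcache_metadata(
    model_config=vllm_config.model_config,
    parallel_config=vllm_config.parallel_config,
    cache_config=vllm_config.cache_config,
)
```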
```python
def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Normalize multimodal information from a Request into parallel lists.

    This helper reads either:
      1) `request.mm_features` (objects each exposing `.identifier` and
         `.mm_position`), or
      2) legacy fields `request.mm_hashes` and `request.mm_positions`.

    It returns two equally sized lists: the multimodal hash identifiers and
    their corresponding positions. If the request contains no multimodal info,
    it returns `([], [])`.

    Args:
        request (Request): The source object.
        modify (bool):
            Controls copy semantics for the legacy-path return values.
            - If True and legacy fields are used, shallow copies are returned
              so the caller can mutate the lists without affecting `request`.
            - If False, the original legacy sequences are returned as-is
              (zero-copy); treat them as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
        May be `([], [])` when no multimodal data is present.
    """
    if getattr(request, "mm_features", None):
        mm_hashes, mm_positions = zip(
            *((f.identifier, f.mm_position) for f in request.mm_features)
        )
        return (list(mm_hashes), list(mm_positions))
    elif getattr(request, "mm_hashes", None):
        if modify:
            return (
                request.mm_hashes.copy(),  # type: ignore
                request.mm_positions.copy(),  # type: ignore
            )
        else:
            return (request.mm_hashes, request.mm_positions)  # type: ignore
    else:
        return ([], [])
```
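
The `modify` flag is the detail worth internalizing: on the legacy path it decides whether the caller gets private copies or aliases of the request's own lists. A hedged sketch, assuming `request` is a `vllm.v1.request.Request` that carries the legacy `mm_hashes`/`mm_positions` fields and the helper above is in scope:

```python
# modify=True: shallow copies, safe to mutate without touching the request.
hashes, positions = extract_mm_features(request, modify=True)
hashes.append("extra-hash")  # request.mm_hashes is unchanged

# modify=False (default): zero-copy aliases; treat them as read-only,
# since mutating them would mutate the request itself.
ro_hashes, ro_positions = extract_mm_features(request)
```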
