
Commit 40db194

Authored by sammshen (Samuel Shen) and ApostaC
[CI]: Add LMCacheConnector Unit Tests (#27852)
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Co-authored-by: Samuel Shen <slshen@uchciago.edu>
Co-authored-by: Yihua Cheng <yihua98@uchicago.edu>
1 parent c765f0b commit 40db194

File tree

4 files changed (+274, -6 lines)


.buildkite/test-amd.yaml

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@ steps:
   - pytest -v -s v1/logits_processors
   - pytest -v -s v1/worker
   - pytest -v -s v1/spec_decode
-  - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit
+  - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_lmcache_integration.py
   - pytest -v -s -m 'not cpu_test' v1/metrics
   - pytest -v -s v1/test_oracle.py
   - pytest -v -s v1/test_request.py
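For readers unfamiliar with the `-m 'not cpu_test'` filter above: it deselects tests that carry the cpu_test marker, which a test opts into roughly as in the illustrative snippet below (not part of this commit), while the new `--ignore` flag skips collection of the LMCache integration test file entirely on this pipeline.

# Illustrative only: a test opts into the cpu_test marker like this, and
# `pytest -m 'not cpu_test'` then deselects it.
import pytest


@pytest.mark.cpu_test
def test_cpu_only_path():
    assert True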

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -316,6 +316,7 @@ steps:
   - vllm/
   - tests/v1
   commands:
+  - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
   # split the test to avoid interference
   - pytest -v -s -m 'not cpu_test' v1/core
   - pytest -v -s v1/executor
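requirements/kv_connectors.txt presumably provides the LMCache package needed by the new tests; a small guard like the following (illustrative only, not part of this commit) is one way to confirm the dependency is importable before running them locally.

# Illustrative only: verify the optional connector dependency is present.
import importlib.util

if importlib.util.find_spec("lmcache") is None:
    raise SystemExit(
        "lmcache not installed; install requirements/kv_connectors.txt first"
    )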
tests/v1/kv_connector/unit/test_lmcache_integration.py

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# NOTE: if your PR has broken one of the tests here (sorry),
# kindly patch the corresponding integration in
# /vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
# or reach out to @ApostaC for assistance

# Assumption vs. Correctness Tests:
# these unit tests do *not* test the correctness of LMCache-side or vLLM-side logic;
# they ensure that the assumptions LMCache makes about vLLM's interface are stable
def assumes(obj, attr, is_callable=False, is_instance_of=None):
    import inspect
    from dataclasses import is_dataclass

    assumption_msg = (
        f"LMCache connector currently assumes that {obj} has a(n) {attr} attribute"
    )
    if hasattr(obj, attr):
        attr_value = getattr(obj, attr)
    elif is_dataclass(obj) and attr in getattr(obj, "__dataclass_fields__", {}):
        field = obj.__dataclass_fields__[attr]
        field_type = field.type
        origin = getattr(field_type, "__origin__", None)
        if origin is not None:
            field_type = origin
        attr_value = field_type
    else:
        raise AssertionError(assumption_msg)
    if is_callable:
        assumption_msg += f" and that {obj}.{attr} is a callable"
        assert callable(attr_value), assumption_msg
    if is_instance_of:
        assumption_msg += f" and that {obj}.{attr} is an instance of {is_instance_of}"
        if isinstance(attr_value, property):
            fget = attr_value.fget
            assert fget is not None, f"Property {obj}.{attr} has no fget"
            sig = inspect.signature(fget)
            ret_anno = sig.return_annotation
            assert ret_anno is not inspect._empty, (
                f"Property {obj}.{attr} has no return annotation"
            )
            assert ret_anno == is_instance_of, assumption_msg
        else:
            if isinstance(attr_value, type):
                assert attr_value is is_instance_of, assumption_msg
            else:
                assert isinstance(attr_value, is_instance_of), assumption_msg


def test_multimodal_interface():
    # protect against interface changes
    from vllm.multimodal.inputs import PlaceholderRange

    assumes(PlaceholderRange, "offset")
    assumes(PlaceholderRange, "length")

    # test a minimal case
    import torch

    from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
        apply_mm_hashes_to_token_ids,
    )

    token_ids = torch.arange(10, dtype=torch.long)
    mm_hashes = ["0000", "1111"]  # hex repr of 0 and 4369
    mm_positions = [
        PlaceholderRange(offset=0, length=4),
        PlaceholderRange(offset=5, length=4),
    ]
    apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions)
    assert token_ids.tolist() == [0, 0, 0, 0, 4, 4369, 4369, 4369, 4369, 9]


def test_config_interface():
    # protect against interface changes
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheConfig
    from vllm.config.kv_transfer import KVTransferConfig
    from vllm.config.model import ModelConfig
    from vllm.config.parallel import ParallelConfig

    assumes(VllmConfig, "model_config")
    assumes(VllmConfig, "cache_config")
    assumes(VllmConfig, "parallel_config")
    assumes(VllmConfig, "kv_transfer_config")

    assumes(KVTransferConfig, "kv_role")
    assumes(KVTransferConfig, "kv_connector_extra_config")

    assumes(ModelConfig, "use_mla", is_instance_of=bool)
    assumes(ModelConfig, "dtype")
    assumes(ModelConfig, "max_model_len")
    assumes(ModelConfig, "get_vocab_size", is_callable=True)
    assumes(ModelConfig, "get_num_attention_heads", is_callable=True)
    assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
    assumes(ModelConfig, "get_head_size", is_callable=True)
    assumes(ModelConfig, "get_num_layers", is_callable=True)
    assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
    assumes(ModelConfig, "model")

    assumes(ParallelConfig, "world_size")
    assumes(ParallelConfig, "rank")
    assumes(ParallelConfig, "tensor_parallel_size")
    assumes(ParallelConfig, "pipeline_parallel_size")
    assumes(ParallelConfig, "data_parallel_size_local")
    assumes(ParallelConfig, "data_parallel_rank_local")

    assumes(CacheConfig, "cache_dtype")
    assumes(CacheConfig, "block_size")
    assumes(CacheConfig, "gpu_memory_utilization")

    # mla metadata minimal cases
    from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
        mla_enabled,
    )

    model_config = ModelConfig(model="deepseek-ai/DeepSeek-R1")
    assert mla_enabled(model_config)
    model_config = ModelConfig(model="Qwen/Qwen3-0.6B")
    assert not mla_enabled(model_config)

    # kv metadata minimal case
    from vllm.utils.torch_utils import get_kv_cache_torch_dtype

    model_config = ModelConfig(dtype="bfloat16")
    parallel_config = ParallelConfig()
    cache_config = CacheConfig(cache_dtype="bfloat16")
    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
    use_mla = mla_enabled(model_config)
    chunk_size = 256
    num_layer = model_config.get_num_layers(parallel_config)
    num_kv_head = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # dummy lmcache metadata creation example
    _ = (
        model_config.model,
        parallel_config.world_size,
        parallel_config.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )


def test_request_interface():
    # protect against interface changes
    from types import NoneType

    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request

    req = Request(
        request_id="test_request",
        prompt_token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        pooling_params=None,
        eos_token_id=100,
        lora_request=None,
    )
    assumes(req, "mm_features", is_instance_of=(list, NoneType))
    assumes(req, "request_id")
    assumes(req, "priority")
    assumes(req, "prompt_token_ids")
    assumes(req, "sampling_params")
    assumes(req, "num_tokens")
    assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType))

    from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem

    assumes(MultiModalFeatureSpec, "identifier")
    assumes(MultiModalFeatureSpec, "mm_position")

    # minimal case:
    from vllm.multimodal.inputs import PlaceholderRange

    request = Request(
        request_id="test_request",
        prompt_token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        pooling_params=None,
        eos_token_id=100,
        lora_request=None,
        mm_features=[
            MultiModalFeatureSpec(
                modality="image",
                identifier="0000",
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=PlaceholderRange(offset=0, length=10),
            )
        ],
    )

    from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
        extract_mm_features,
    )

    mm_hashes, mm_positions = extract_mm_features(request)
    assert isinstance(mm_hashes, list)
    assert len(mm_hashes) == 1
    assert isinstance(mm_positions, list)
    assert len(mm_positions) == 1
    assert mm_positions[0].offset == 0
    assert mm_positions[0].length == 10


def test_new_request_interface():
    # protect against interface changes
    from vllm.v1.core.sched.output import NewRequestData

    assumes(NewRequestData, "req_id")
    assumes(NewRequestData, "block_ids")
    assumes(NewRequestData, "prompt_token_ids")
    assumes(NewRequestData, "sampling_params")


def test_sampling_params_interface():
    # protect against interface changes
    from vllm.sampling_params import SamplingParams

    assumes(SamplingParams, "extra_args")

    # dummy example use case in LMCache
    kv_transfer_params = {
        "lmcache.tag.user": "example_user_1",
        "lmcache.ttl": 60,
    }
    sampling_params = SamplingParams(
        extra_args={"kv_transfer_params": kv_transfer_params}
    )
    assert sampling_params.extra_args["kv_transfer_params"] == kv_transfer_params


def test_tp_interface():
    # protect against interface changes
    import inspect

    from vllm.distributed.parallel_state import get_tp_group

    sig = inspect.signature(get_tp_group)
    GroupCoordinator = sig.return_annotation

    assumes(GroupCoordinator, "broadcast", is_callable=True)
    assumes(GroupCoordinator, "broadcast_object", is_callable=True)


def test_forward_context_interface():
    # protect against interface changes
    from vllm.forward_context import ForwardContext

    assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
    assumes(ForwardContext, "virtual_engine")
    assumes(ForwardContext, "attn_metadata")


def test_scheduler_output_interface():
    # protect against interface changes
    from vllm.v1.core.sched.output import SchedulerOutput

    assumes(SchedulerOutput, "finished_req_ids")
    assumes(SchedulerOutput, "scheduled_new_reqs", is_instance_of=list)
    assumes(SchedulerOutput, "num_scheduled_tokens", is_instance_of=dict)
    assumes(SchedulerOutput, "scheduled_cached_reqs")

    from vllm.v1.core.sched.output import CachedRequestData

    assumes(CachedRequestData, "req_ids", is_instance_of=list)
    assumes(CachedRequestData, "new_block_ids", is_instance_of=list)
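To make the assumes() pattern above easier to follow, here is a minimal, self-contained sketch (not part of the commit; ToyConfig and assert_has_field are invented names) that applies the same hasattr/dataclass-field check to a toy class.

# Illustrative only -- ToyConfig and assert_has_field are hypothetical names,
# not vLLM or LMCache code; they mirror the idea behind assumes().
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    block_size: int = 16

    def get_vocab_size(self) -> int:
        return 32000


def assert_has_field(cls, name):
    # A class "exposes" an attribute if it is a plain attribute/method or a
    # declared dataclass field (fields without defaults only appear on
    # instances, so hasattr() alone would miss them).
    declared = {f.name for f in fields(cls)} if hasattr(cls, "__dataclass_fields__") else set()
    assert hasattr(cls, name) or name in declared, (
        f"{cls.__name__} is assumed to expose '{name}'"
    )


assert_has_field(ToyConfig, "block_size")      # dataclass field
assert_has_field(ToyConfig, "get_vocab_size")  # regular method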

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py

Lines changed: 1 addition & 5 deletions
@@ -724,7 +724,7 @@ def get_inference_info(self) -> dict:
             "max_model_len": getattr(
                 vllm_config.model_config, "max_model_len", None
             ),
-            "vocab_size": getattr(vllm_config.model_config, "vocab_size", None),
+            "vocab_size": vllm_config.model_config.get_vocab_size(),
             "num_layers": getattr(
                 vllm_config.model_config, "get_num_layers", lambda _: None
             )(vllm_config.parallel_config),
@@ -746,10 +746,6 @@ def get_inference_info(self) -> dict:
             "gpu_memory_utilization": getattr(
                 vllm_config.cache_config, "gpu_memory_utilization", None
             ),
-            "swap_space": getattr(vllm_config.cache_config, "swap_space", None),
-            "enable_prefix_caching": getattr(
-                vllm_config.cache_config, "enable_prefix_caching", None
-            ),
         },
     }
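For context on the vocab_size change above, the toy contrast below (hypothetical names, not vLLM code) shows why calling the accessor pinned by test_config_interface() is preferable to a getattr on a plain attribute: the getattr form degrades silently to None if the attribute is absent or renamed, while the accessor keeps working or fails loudly.

# Hypothetical sketch: ToyModelConfig stands in for a config object that only
# exposes an accessor, not a vocab_size attribute.
class ToyModelConfig:
    def get_vocab_size(self) -> int:
        return 32000


cfg = ToyModelConfig()
assert getattr(cfg, "vocab_size", None) is None   # old path: silent None
assert cfg.get_vocab_size() == 32000              # new path: the pinned accessor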
