Skip to content

Commit 15ec91b

Browse files
committed
[KV Connector] Make KVCacheConfig an explicit constructor argument
Follow on from vllm-project#25712 `VllmConfig` is explicitly designed as a dataclass containing user-provided configuration and model metadata. It is a global configuration object that lives throughout the entire engine lifetime and is meant to be immutable after `__post_init__()`. `KVCacheConfig` is worker-specific, runtime-computed state. It has limited lifetime, and its purpose is limited to initializing the KV Cache in the model runner. Even if we add KV cache hints to model config.json in future, this would be parsed into `ModelConfig`, used as input to the `get_kv_cache_configs()` computation, and the resulting `KVCacheConfig` would still be runtime state. We are currently creating per-worker copies of VllmConfig in order to attach the runtime `KVCacheConfig` state. But instead we should just explicitly pass `KVCacheConfig` to the connector. Make sure to handle backwards compatibility for external connector implementations (loaded via module path) that have the old style constructor signature. Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 3857eb8 commit 15ec91b

File tree

14 files changed

+408
-41
lines changed

14 files changed

+408
-41
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Unit tests for backwards compatibility with external KV connector implementations.
5+
6+
This test ensures that external connectors (loaded via kv_connector_module_path)
7+
implemented with the old signature continue to work:
8+
- Old signature: __init__(self, vllm_config, role)
9+
- New signature: __init__(self, vllm_config, role, kv_cache_config)
10+
"""
11+
12+
from typing import TYPE_CHECKING
13+
from unittest.mock import patch
14+
15+
import pytest
16+
17+
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
18+
from vllm.distributed.kv_transfer.kv_connector.v1 import (
19+
KVConnectorBase_V1,
20+
KVConnectorRole,
21+
)
22+
from vllm.v1.core.sched.output import SchedulerOutput
23+
24+
from .utils import create_scheduler, create_vllm_config
25+
26+
if TYPE_CHECKING:
27+
from vllm.attention.backends.abstract import AttentionMetadata
28+
from vllm.config import VllmConfig
29+
from vllm.forward_context import ForwardContext
30+
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
31+
from vllm.v1.kv_cache_interface import KVCacheConfig
32+
from vllm.v1.request import Request
33+
34+
35+
class OldStyleTestConnector(KVConnectorBase_V1):
36+
"""
37+
Test connector using the old signature with 2 required arguments.
38+
This simulates external connectors that haven't been updated yet.
39+
"""
40+
41+
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
42+
# Old-style call to super().__init__ with only 2 arguments
43+
super().__init__(vllm_config=vllm_config, role=role)
44+
45+
def get_num_new_matched_tokens(
46+
self, request: "Request", num_computed_tokens: int
47+
) -> tuple[int | None, bool]:
48+
return 0, False
49+
50+
def update_state_after_alloc(
51+
self,
52+
request: "Request",
53+
blocks: "KVCacheBlocks",
54+
num_external_tokens: int,
55+
):
56+
pass
57+
58+
def build_connector_meta(self, scheduler_output: SchedulerOutput):
59+
return None
60+
61+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
62+
pass
63+
64+
def wait_for_layer_load(self, layer_name: str) -> None:
65+
pass
66+
67+
def save_kv_layer(
68+
self,
69+
layer_name: str,
70+
kv_layer,
71+
attn_metadata: "AttentionMetadata",
72+
**kwargs,
73+
) -> None:
74+
pass
75+
76+
def wait_for_save(self):
77+
pass
78+
79+
80+
class NewStyleTestConnector(KVConnectorBase_V1):
81+
"""
82+
Test connector using the new signature with 3 required arguments.
83+
"""
84+
85+
def __init__(
86+
self,
87+
vllm_config: "VllmConfig",
88+
role: KVConnectorRole,
89+
kv_cache_config: "KVCacheConfig",
90+
):
91+
# New-style call to super().__init__ with all 3 arguments
92+
super().__init__(
93+
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
94+
)
95+
96+
def get_num_new_matched_tokens(
97+
self, request: "Request", num_computed_tokens: int
98+
) -> tuple[int | None, bool]:
99+
return 0, False
100+
101+
def update_state_after_alloc(
102+
self,
103+
request: "Request",
104+
blocks: "KVCacheBlocks",
105+
num_external_tokens: int,
106+
):
107+
pass
108+
109+
def build_connector_meta(self, scheduler_output: SchedulerOutput):
110+
return None
111+
112+
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
113+
pass
114+
115+
def wait_for_layer_load(self, layer_name: str) -> None:
116+
pass
117+
118+
def save_kv_layer(
119+
self,
120+
layer_name: str,
121+
kv_layer,
122+
attn_metadata: "AttentionMetadata",
123+
**kwargs,
124+
) -> None:
125+
pass
126+
127+
def wait_for_save(self):
128+
pass
129+
130+
131+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
132+
def test_external_old_signature_factory_instantiation(role):
133+
"""
134+
Test that external connectors with old signature (2 required args) loaded
135+
via kv_connector_module_path are correctly instantiated with backwards
136+
compatibility support.
137+
"""
138+
vllm_config = create_vllm_config()
139+
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
140+
vllm_config.kv_transfer_config.kv_connector_module_path = (
141+
"tests.v1.kv_connector.unit.test_backwards_compatibility"
142+
)
143+
144+
scheduler = create_scheduler(vllm_config)
145+
kv_cache_config = scheduler.kv_cache_config
146+
147+
connector = KVConnectorFactory.create_connector(vllm_config, kv_cache_config, role)
148+
149+
assert connector is not None
150+
assert isinstance(connector, OldStyleTestConnector)
151+
assert connector.role == role
152+
assert connector._kv_cache_config is None
153+
154+
155+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
156+
def test_external_new_signature_factory_instantiation(role):
157+
"""
158+
Test that external connectors with new signature (3 required args) loaded
159+
via kv_connector_module_path are correctly instantiated.
160+
"""
161+
vllm_config = create_vllm_config()
162+
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
163+
vllm_config.kv_transfer_config.kv_connector_module_path = (
164+
"tests.v1.kv_connector.unit.test_backwards_compatibility"
165+
)
166+
167+
scheduler = create_scheduler(vllm_config)
168+
kv_cache_config = scheduler.kv_cache_config
169+
170+
connector = KVConnectorFactory.create_connector(vllm_config, kv_cache_config, role)
171+
172+
assert connector is not None
173+
assert isinstance(connector, NewStyleTestConnector)
174+
assert connector.role == role
175+
assert connector._kv_cache_config is not None
176+
assert connector._kv_cache_config == kv_cache_config
177+
178+
179+
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
180+
def test_old_signature_super_init(role):
181+
"""
182+
Test that old-style connectors can call super().__init__() without
183+
kv_cache_config parameter.
184+
"""
185+
vllm_config = create_vllm_config()
186+
187+
connector = OldStyleTestConnector(vllm_config, role)
188+
189+
assert connector is not None
190+
assert connector.role == role
191+
assert connector._kv_cache_config is None
192+
193+
194+
def test_old_signature_super_init_with_kwargs():
195+
"""
196+
Test that old-style connectors can call super().__init__() with keyword
197+
arguments in different orders.
198+
"""
199+
vllm_config = create_vllm_config()
200+
201+
# Test with vllm_config= and role= kwargs
202+
connector1 = OldStyleTestConnector(
203+
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
204+
)
205+
assert connector1 is not None
206+
assert connector1._kv_cache_config is None
207+
208+
# Test with role= and vllm_config= in reversed order
209+
connector2 = OldStyleTestConnector(
210+
role=KVConnectorRole.WORKER, vllm_config=vllm_config
211+
)
212+
assert connector2 is not None
213+
assert connector2._kv_cache_config is None
214+
215+
216+
def test_internal_connector_uses_new_signature():
217+
"""
218+
Test that internal connectors (registered in factory) always use the new
219+
signature and get kv_cache_config.
220+
"""
221+
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
222+
SharedStorageConnector,
223+
)
224+
225+
vllm_config = create_vllm_config()
226+
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"
227+
228+
scheduler = create_scheduler(vllm_config)
229+
kv_cache_config = scheduler.kv_cache_config
230+
231+
connector = KVConnectorFactory.create_connector(
232+
vllm_config, kv_cache_config, KVConnectorRole.SCHEDULER
233+
)
234+
235+
assert connector is not None
236+
assert isinstance(connector, SharedStorageConnector)
237+
assert connector._kv_cache_config is not None
238+
assert connector._kv_cache_config == kv_cache_config
239+
240+
241+
def test_signature_detection_with_mocking():
242+
"""
243+
Test that the factory correctly applies compat_sig flag returned from
244+
_get_connector_class_with_compat.
245+
"""
246+
vllm_config = create_vllm_config()
247+
scheduler = create_scheduler(vllm_config)
248+
kv_cache_config = scheduler.kv_cache_config
249+
250+
# Mock _get_connector_class_with_compat to return old-style connector
251+
with patch.object(
252+
KVConnectorFactory,
253+
"_get_connector_class_with_compat",
254+
return_value=(OldStyleTestConnector, True),
255+
):
256+
old_connector = KVConnectorFactory.create_connector(
257+
vllm_config, kv_cache_config, KVConnectorRole.SCHEDULER
258+
)
259+
assert old_connector is not None
260+
assert isinstance(old_connector, OldStyleTestConnector)
261+
assert old_connector._kv_cache_config is None
262+
263+
# Mock _get_connector_class_with_compat to return new-style connector
264+
with patch.object(
265+
KVConnectorFactory,
266+
"_get_connector_class_with_compat",
267+
return_value=(NewStyleTestConnector, False),
268+
):
269+
new_connector = KVConnectorFactory.create_connector(
270+
vllm_config, kv_cache_config, KVConnectorRole.SCHEDULER
271+
)
272+
assert new_connector is not None
273+
assert isinstance(new_connector, NewStyleTestConnector)
274+
assert new_connector._kv_cache_config is not None
275+
assert new_connector._kv_cache_config == kv_cache_config

tests/v1/kv_connector/unit/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def create_model_runner_output(
254254

255255

256256
class TestSharedStorageConnector(SharedStorageConnector):
257-
def __init__(self, config: VllmConfig, role):
257+
def __init__(self, config: VllmConfig, role, kv_cache_config):
258258
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
259259
self._connector = SharedStorageConnector(config, role)
260260
self.call_record: dict[str, int] = defaultdict(int)

vllm/distributed/kv_transfer/kv_connector/factory.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import TYPE_CHECKING, cast
77

88
import vllm.envs as envs
9-
from vllm.config import VllmConfig
109
from vllm.distributed.kv_transfer.kv_connector.base import (
1110
KVConnectorBase,
1211
KVConnectorBaseType,
@@ -16,9 +15,12 @@
1615
supports_hma,
1716
)
1817
from vllm.logger import init_logger
18+
from vllm.utils.func_utils import supports_kw
1919

2020
if TYPE_CHECKING:
21+
from vllm.config import VllmConfig
2122
from vllm.config.kv_transfer import KVTransferConfig
23+
from vllm.v1.kv_cache_interface import KVCacheConfig
2224

2325
logger = init_logger(__name__)
2426

@@ -41,7 +43,8 @@ def loader() -> type[KVConnectorBase]:
4143
@classmethod
4244
def create_connector(
4345
cls,
44-
config: VllmConfig,
46+
config: "VllmConfig",
47+
kv_cache_config: "KVCacheConfig",
4548
role: KVConnectorRole,
4649
) -> KVConnectorBase:
4750
if not envs.VLLM_USE_V1:
@@ -53,7 +56,9 @@ def create_connector(
5356
kv_transfer_config = config.kv_transfer_config
5457
if kv_transfer_config is None:
5558
raise ValueError("kv_transfer_config must be set to create a connector")
56-
connector_cls = cls.get_connector_class(kv_transfer_config)
59+
connector_cls, compat_sig = cls._get_connector_class_with_compat(
60+
kv_transfer_config
61+
)
5762

5863
# check if the connector supports HMA
5964
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
@@ -76,7 +81,12 @@ def create_connector(
7681
# - Co-locate with worker process
7782
# - Should only be used inside the forward context & attention layer
7883
# We build separately to enforce strict separation
79-
return connector_cls(config, role)
84+
if compat_sig:
85+
# Old signature: __init__(self, vllm_config, role)
86+
return connector_cls(config, role)
87+
else:
88+
# New signature: __init__(self, vllm_config, role, kv_cache_config)
89+
return connector_cls(config, role, kv_cache_config)
8090

8191
@classmethod
8292
def get_connector_class_by_name(
@@ -97,13 +107,13 @@ def get_connector_class_by_name(
97107
return cls._registry[connector_name]()
98108

99109
@classmethod
100-
def get_connector_class(
110+
def _get_connector_class_with_compat(
101111
cls, kv_transfer_config: "KVTransferConfig"
102-
) -> type[KVConnectorBaseType]:
103-
"""Get the connector class by name."""
112+
) -> tuple[type[KVConnectorBaseType], bool]:
104113
connector_name = kv_transfer_config.kv_connector
105114
if connector_name is None:
106115
raise ValueError("Connector name is not set in KVTransferConfig")
116+
compat_sig = False
107117
if connector_name in cls._registry:
108118
connector_cls = cls._registry[connector_name]()
109119
else:
@@ -118,6 +128,21 @@ def get_connector_class(
118128
f"Class {connector_name} not found in {connector_module_path}"
119129
) from e
120130
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
131+
if not supports_kw(connector_cls, "kv_cache_config"):
132+
compat_sig = True
133+
logger.warning(
134+
"Connector %s uses deprecated signature with 2 required arguments. "
135+
"Please update to include kv_cache_config as the second argument.",
136+
connector_cls.__name__,
137+
)
138+
return connector_cls, compat_sig
139+
140+
@classmethod
141+
def get_connector_class(
142+
cls, kv_transfer_config: "KVTransferConfig"
143+
) -> type[KVConnectorBaseType]:
144+
"""Get the connector class by name."""
145+
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
121146
return connector_cls
122147

123148

0 commit comments

Comments
 (0)