
Commit d1f8b3b

xuechendi authored and charlifu committed

[NIXL][OOT platform] support nixl_connector with oot platform and other nixl_backend (vllm-project#25121)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>

1 parent 31259d2 · commit d1f8b3b

File tree

5 files changed: +99 −9 lines changed

docs/features/disagg_prefill.md

Lines changed: 6 additions & 0 deletions

````diff
@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
     --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
     ```
 
+    For NixlConnector, you may also specify one or more NIXL backends. For example:
+
+    ```bash
+    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'
+    ```
+
 - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
 
     ```bash
````
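Since the value passed to `--kv-transfer-config` is parsed as JSON, a malformed string (for example, a missing closing brace) fails at startup. A quick standard-library check catches quoting mistakes before launching; this is a minimal sketch, not part of the commit:

```python
import json

# Validate the connector config string before passing it on the command line.
cfg = json.loads(
    '{"kv_connector": "NixlConnector", "kv_role": "kv_both",'
    ' "kv_buffer_device": "cuda",'
    ' "kv_connector_extra_config": {"backends": ["UCX", "GDS"]}}'
)
assert cfg["kv_connector_extra_config"]["backends"] == ["UCX", "GDS"]
```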

docs/serving/expert_parallel_deployment.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
 
 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
 
-2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`
+2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'`
 
 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
 
```
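For in-process use, the same settings can be built programmatically. A hedged sketch: it assumes `KVTransferConfig` is importable from `vllm.config`, as in current vLLM releases; only `get_from_extra_config` and the `"backends"` key (default `["UCX"]`) are confirmed by this diff:

```python
from vllm.config import KVTransferConfig

# Programmatic equivalent of the CLI flag above; "backends" is the key the
# connector worker reads, defaulting to ["UCX"] when unset (per this commit).
kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"backends": ["UCX", "GDS"]},
)

assert kv_transfer_config.get_from_extra_config("backends", ["UCX"]) == ["UCX", "GDS"]
```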

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 51 additions & 1 deletion

```diff
@@ -27,6 +27,7 @@
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker, NixlKVConnectorStats)
 from vllm.forward_context import ForwardContext
+from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -56,7 +57,7 @@ def __init__(self, agent_name: str, *args, **kwargs):
     def get_reg_descs(self, caches_data, memory_type: str) -> list:
         return [str(uuid.uuid4()) for _ in caches_data]
 
-    def register_memory(self, descs) -> None:
+    def register_memory(self, descs, backends) -> None:
         pass
 
     def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
@@ -855,3 +856,52 @@ def test_register_kv_caches(dist_init):
         assert block_len == expected_block_len, \
             f"Block entry {i}: Expected block len {expected_block_len}, " \
             f"got {block_len}"
+
+
+class FakePlatform(Platform):
+    device_type: str = "oot"
+
+    @classmethod
+    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
+        """
+        Returns a mapping from device_type to a tuple of supported
+        kv_buffer_device for nixl.
+        """
+        return {'oot': ('oot', )}
+
+    @classmethod
+    def get_nixl_memory_type(cls) -> Optional[str]:
+        """
+        Returns the nixl memory type for the current platform.
+        """
+        return 'VRAM'
+
+
+@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
+    ("oot", "VRAM"),
+])
+def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
+                                        nixl_memory_type):
+    """
+    Test that register_kv_caches() passes the correct memory types from the
+    config to the nixl_wrapper.
+    """
+    vllm_config = create_vllm_config()
+    # Override the default memory types in the config
+    vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
+    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
+        _NIXL_SUPPORTED_DEVICE)
+    _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
+
+    with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE):  # noqa: E501
+
+        # Create connector and replace its worker with a fake one for isolation
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+
+        # Verify get_reg_descs was called with the correct memory_type
+        assert connector.connector_worker.kv_buffer_device == kv_buffer_device
+        assert connector.connector_worker.nixl_memory_type == nixl_memory_type
```
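To exercise just the new OOT-platform test, an invocation like the following should work from a vLLM source checkout (the path and `-k` filter are taken from the diff; test dependencies are assumed installed):

```python
import pytest

# Run only the new parametrized test from this commit.
raise SystemExit(pytest.main([
    "tests/v1/kv_connector/unit/test_nixl_connector.py",
    "-k", "test_kv_buffer_to_nixl_memory_types",
]))
```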

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 26 additions & 7 deletions

```diff
@@ -58,13 +58,21 @@
     logger.warning("NIXL is not available")
     NixlWrapper = None
 
+try:
+    from nixl._api import nixl_agent_config
+except ImportError:
+    nixl_agent_config = None
+    logger.warning("NIXL agent config is not available")
+
 # Supported platforms and types of kv transfer buffer.
 # {device: tuple of supported kv buffer types}
 _NIXL_SUPPORTED_DEVICE = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
 }
+# support for oot platform by providing mapping in current_platform
+_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
 
 
 class NixlAgentMetadata(
```
```diff
@@ -448,8 +456,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
 
+        self.nixl_backends = \
+            vllm_config.kv_transfer_config.get_from_extra_config(
+                "backends", ["UCX"])
         # Agent.
-        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
+        config = nixl_agent_config(backends=self.nixl_backends) if len(
+            non_ucx_backends) > 0 and nixl_agent_config is not None else None
+
+        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
         self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
 
```
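The agent construction above boils down to: request a NIXL agent config only when a non-UCX backend is listed and the installed NIXL exposes `nixl_agent_config`; otherwise pass `None`, preserving the previous UCX-only behavior. A standalone sketch (`make_agent_config` is a hypothetical helper name, not in the commit):

```python
try:
    from nixl._api import nixl_agent_config
except ImportError:  # older NIXL releases lack nixl_agent_config
    nixl_agent_config = None


def make_agent_config(nixl_backends: list[str]):
    """Mirror of the __init__ logic above (hypothetical helper)."""
    non_ucx_backends = [b for b in nixl_backends if b != "UCX"]
    if non_ucx_backends and nixl_agent_config is not None:
        # e.g. ["UCX", "GDS"] -> an explicit agent config listing both backends
        return nixl_agent_config(backends=nixl_backends)
    return None  # plain UCX, or a NIXL too old to accept an agent config
```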

```diff
@@ -486,11 +501,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # used when device memory can not be registered under nixl
         self.host_xfer_buffers: dict[str, torch.Tensor] = {}
         self.use_host_buffer = self.kv_buffer_device == "cpu"
-        if self.kv_buffer_device == "cuda":
-            self.nixl_memory_type = "VRAM"
-        elif self.kv_buffer_device == "cpu":
-            self.nixl_memory_type = "DRAM"
-        else:
+        # support for oot platform which can't register nixl memory
+        # type based on kv_buffer_device
+        self.nixl_memory_type = current_platform.get_nixl_memory_type()
+        if self.nixl_memory_type is None:
+            if self.kv_buffer_device == "cuda":
+                self.nixl_memory_type = "VRAM"
+            elif self.kv_buffer_device == "cpu":
+                self.nixl_memory_type = "DRAM"
+        if self.nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported.")
@@ -766,7 +785,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                                 self.nixl_memory_type)
         logger.debug("Registering descs: %s", caches_data)
-        self.nixl_wrapper.register_memory(descs)
+        self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
```
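The memory-type resolution in the first hunk above follows a clear precedence: the platform hook wins, then the built-in `cuda`→`VRAM` / `cpu`→`DRAM` mapping, and anything still unresolved raises. As a standalone sketch (the helper name is hypothetical):

```python
from typing import Optional

_DEFAULT_MEMORY_TYPES = {"cuda": "VRAM", "cpu": "DRAM"}


def resolve_nixl_memory_type(platform_hint: Optional[str],
                             kv_buffer_device: str) -> str:
    """Platform hook first, then the built-in mapping, else unsupported."""
    memory_type = platform_hint or _DEFAULT_MEMORY_TYPES.get(kv_buffer_device)
    if memory_type is None:
        raise RuntimeError(
            f"kv_buffer_device {kv_buffer_device!r} is not supported.")
    return memory_type
```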

vllm/platforms/interface.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -604,6 +604,21 @@ def _synced_weight_loader(param, *args, **kwargs):
 
         return _synced_weight_loader
 
+    @classmethod
+    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
+        """
+        Returns a mapping from device_type to a tuple of supported
+        kv_buffer_device for nixl.
+        """
+        return {}
+
+    @classmethod
+    def get_nixl_memory_type(cls) -> Optional[str]:
+        """
+        Returns the nixl memory type for the current platform.
+        """
+        return None
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
```
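Putting the two hooks together, an out-of-tree platform opts in by overriding them, exactly as the test's `FakePlatform` does. The class and device names below are illustrative, not part of the commit:

```python
from typing import Optional

from vllm.platforms.interface import Platform


class MyAccelPlatform(Platform):
    """Illustrative OOT platform, registered through vLLM's platform plugin mechanism."""
    device_type: str = "myaccel"

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        # Merged into _NIXL_SUPPORTED_DEVICE when nixl_connector is imported.
        return {"myaccel": ("myaccel", )}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        # Used directly, bypassing the built-in cuda/cpu fallback mapping.
        return "VRAM"
```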
