Commit e599e2c

[XPU][P/D] Add XPU support in NixlConnector (#22436)

Signed-off-by: zhenwei <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

1 parent c29fb54, commit e599e2c

File tree: 7 files changed, +114 -71 lines

requirements/xpu.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ wheel
 jinja2>=3.1.6
 datasets # for benchmark scripts
 numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
+nixl==0.3.0 # for PD disaggregation
 --extra-index-url=https://download.pytorch.org/whl/xpu
 torch==2.8.0+xpu
 torchaudio

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 49 additions & 1 deletion
@@ -6,7 +6,7 @@
 from collections import defaultdict
 from collections.abc import Sequence
 from concurrent.futures import CancelledError, Future
-from typing import Optional, cast
+from typing import Literal, Optional, Union, cast

 import torch

@@ -196,3 +196,51 @@ def callback(fut):
         output_future.add_done_callback(make_callback(i))

     return result_future
+
+
+def _make_src_and_dst_indices(
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    src_indices = torch.tensor(src_block_ids,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_block_ids,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
+
+
+def copy_kv_blocks(
+    src_kv_caches: dict[str, torch.Tensor],
+    dst_kv_caches: dict[str, torch.Tensor],
+    src_block_ids: list[int],
+    dst_block_ids: list[int],
+    direction: Literal["h2d", "d2h"],
+) -> None:
+    """Copy kv blocks between different buffers."""
+    if not src_kv_caches or not dst_kv_caches or \
+        not src_block_ids or not dst_block_ids or \
+        len(src_block_ids) != len(dst_block_ids):
+        return
+
+    src_device = next(iter(src_kv_caches.values())).device
+    dst_device = next(iter(dst_kv_caches.values())).device
+
+    src_indices, dst_indices = _make_src_and_dst_indices(
+        src_block_ids=src_block_ids,
+        dst_block_ids=dst_block_ids,
+        src_device=src_device,
+        dst_device=dst_device)
+
+    from vllm.platforms import current_platform
+    if direction == "h2d":
+        copy_fn = current_platform.insert_blocks_to_device
+    else:
+        copy_fn = current_platform.swap_out_blocks_to_host
+    for layer_name in src_kv_caches:
+        src_tensor = src_kv_caches[layer_name]
+        dst_tensor = dst_kv_caches[layer_name]
+        copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
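
The `copy_kv_blocks` helper above only builds the int64 index tensors and dispatches to the current platform's block-copy hooks. A minimal, self-contained sketch of the indexing pattern those hooks implement (not vLLM code; both "host" and "device" buffers are plain CPU tensors here, and the shapes and layer names are illustrative assumptions):

    import torch

    num_blocks, block_size, num_heads, head_dim = 8, 16, 2, 4
    layers = ["layer.0", "layer.1"]

    # Per-layer KV buffers with blocks indexed along dim 0 (TPU-style layout).
    host_caches = {n: torch.randn(num_blocks, block_size, num_heads, head_dim)
                   for n in layers}
    dev_caches = {n: torch.zeros_like(host_caches[n]) for n in layers}

    # Block-id lists become int64 index tensors, as in _make_src_and_dst_indices.
    src_idx = torch.tensor([0, 3, 5], dtype=torch.int64)
    dst_idx = torch.tensor([2, 4, 6], dtype=torch.int64)

    # "h2d" direction: insert host blocks into the device buffer, layer by layer.
    for name in layers:
        dev_caches[name][dst_idx] = host_caches[name][src_idx]

    assert torch.equal(dev_caches["layer.0"][2], host_caches["layer.0"][0])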

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@
 _NIXL_SUPPORTED_XPUS = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
+    "xpu": ("cpu", ),
 }

vllm/platforms/tpu.py

Lines changed: 26 additions & 0 deletions
@@ -200,6 +200,32 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                     model_config: "ModelConfig") -> bool:
         return True

+    @classmethod
+    @torch.compile(backend="openxla")
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].to(
+            dst_cache.device)
+
+    @classmethod
+    @torch.compile(backend="openxla")
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """ tpu blocks to cpu blocks"""
+        torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
+        dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
+

 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform

vllm/platforms/xpu.py

Lines changed: 31 additions & 0 deletions
@@ -164,6 +164,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)

+        if (envs.VLLM_KV_CACHE_LAYOUT is None
+                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
+            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
+            logger.info(
+                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                "only NHD layout is supported by XPU attention kernels.")
+
     @classmethod
     def is_pin_memory_available(cls):
         return True
@@ -210,3 +217,27 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def insert_blocks_to_device(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from src_cache to dst_cache on XPU."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
+
+    @classmethod
+    def swap_out_blocks_to_host(
+        cls,
+        src_cache: torch.Tensor,
+        dst_cache: torch.Tensor,
+        src_block_indices: torch.Tensor,
+        dst_block_indices: torch.Tensor,
+    ) -> None:
+        """Copy blocks from XPU to host (CPU)."""
+        _src_cache = src_cache[:, src_block_indices]
+        dst_cache[:, dst_block_indices] = _src_cache.cpu()
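
Note that the TPU hooks index blocks along dim 0 of the cache tensor, while the XPU hooks index along dim 1. A minimal sketch of the XPU-style indexing with plain CPU tensors, assuming an illustrative (2, num_blocks, block_size, num_heads, head_dim) layout where dim 0 separates K and V:

    import torch

    # Illustrative layout assumption: dim 0 is K/V, dim 1 is the block index.
    num_blocks, block_size, num_heads, head_dim = 8, 16, 2, 4
    dev_cache = torch.randn(2, num_blocks, block_size, num_heads, head_dim)
    host_cache = torch.zeros_like(dev_cache)

    src_idx = torch.tensor([1, 4], dtype=torch.int64)
    dst_idx = torch.tensor([0, 2], dtype=torch.int64)

    # swap_out_blocks_to_host pattern: gather the selected blocks on dim 1
    # and write them into the destination slots of the host buffer.
    host_cache[:, dst_idx] = dev_cache[:, src_idx].cpu()

    assert torch.equal(host_cache[:, 0], dev_cache[:, 1])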

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,7 @@
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
@@ -3139,6 +3140,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
+            if self.device.type == 'xpu':
+                get_kv_transfer_group().set_host_xfer_buffer_ops(
+                    copy_kv_blocks)

     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 70 deletions
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch

 import numpy as np
@@ -23,6 +23,7 @@
     get_layers_from_vllm_config, update_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]


-def _make_src_and_dst_indices(
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    src_device: Union[torch.device, str],
-    dst_device: Union[torch.device, str],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    src_indices = torch.tensor(src_block_ids,
-                               device=src_device,
-                               dtype=torch.int64)
-    dst_indices = torch.tensor(dst_block_ids,
-                               device=dst_device,
-                               dtype=torch.int64)
-    return src_indices, dst_indices
-
-
-@torch.compile(backend="openxla")
-def _insert_blocks_to_tpu(
-    cpu_cache: torch.Tensor,
-    tpu_cache: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-) -> None:
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to(
-        tpu_cache.device)
-
-
-@torch.compile(backend="openxla")
-def _swap_out_tpu_blocks(
-    tpu_cache: torch.Tensor,
-    cpu_cache: torch.Tensor,
-    tpu_block_indices: torch.Tensor,
-    cpu_block_indices: torch.Tensor,
-) -> None:
-    """ tpu blocks to cpu blocks"""
-    torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
-    cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu()
-
-
-def copy_kv_blocks(
-    src_kv_caches: dict[str, torch.Tensor],
-    dst_kv_caches: dict[str, torch.Tensor],
-    src_block_ids: list[int],
-    dst_block_ids: list[int],
-    direction: Literal["h2d", "d2h"],
-) -> None:
-    """Copy kv blocks between different buffers."""
-    if not src_kv_caches or not dst_kv_caches or \
-        not src_block_ids or not dst_block_ids or \
-        len(src_block_ids) != len(dst_block_ids):
-        return
-
-    src_device = next(iter(src_kv_caches.values())).device
-    dst_device = next(iter(dst_kv_caches.values())).device
-
-    src_indices, dst_indices = _make_src_and_dst_indices(
-        src_block_ids=src_block_ids,
-        dst_block_ids=dst_block_ids,
-        src_device=src_device,
-        dst_device=dst_device)
-
-    _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \
-        _swap_out_tpu_blocks
-    for layer_name in src_kv_caches:
-        src_tensor = src_kv_caches[layer_name]
-        dst_tensor = dst_kv_caches[layer_name]
-        _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
-
-
 def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                            page_size: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
