
Commit c7d0cb3

ilmarkov authored and tlrmchlsmth committed
[Misc] Remove use of CUDA_VISIBLE_DEVICES for device selection (fix DP slow startup time &c) (vllm-project#26709)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent 6a088ca commit c7d0cb3

File tree

4 files changed: +35 -7 lines changed


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 8 additions & 4 deletions
@@ -991,11 +991,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
+        self.device_id = self.tp_rank
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
 
             for cache in cache_list:
                 base_addr = cache.data_ptr()
+                if not self.use_host_buffer and current_platform.is_cuda_alike():
+                    self.device_id = cache.device.index
                 if base_addr in seen_base_addresses:
                     continue
 
@@ -1023,7 +1026,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
+                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                 )
 
         logger.debug(
@@ -1070,7 +1073,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.device_id))
 
             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1081,12 +1084,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
         logger.debug(
-            "Created %s blocks for src engine %s and rank %s",
+            "Created %s blocks for src engine %s and rank %s on device id %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
+            self.device_id,
         )
 
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
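
In effect, the connector now asks the KV cache tensor which GPU it lives on instead of assuming cuda:<tp_rank>, falling back to the TP rank for host buffers and non-CUDA-like platforms. A minimal standalone sketch of that decision; the helper name is hypothetical and is_cuda stands in for the platform check:

import torch

def resolve_device_id(cache: torch.Tensor, tp_rank: int, use_host_buffer: bool) -> int:
    # Hypothetical helper mirroring the new logic: trust the tensor's own
    # device index on CUDA-like platforms, otherwise keep the old tp_rank
    # fallback (host buffers have no meaningful GPU index).
    if not use_host_buffer and cache.is_cuda:
        return cache.device.index
    return tp_rank

# Example: a worker with tp_rank 1 whose KV cache was allocated on cuda:3
# would now register device id 3 with NIXL rather than 1.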

vllm/v1/engine/utils.py

Lines changed: 4 additions & 1 deletion
@@ -134,9 +134,12 @@ def __init__(
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                # Adjust device control in DP for non-CUDA platforms
+                # For CUDA platforms, setting same device id for different DP
+                # processes affects NCCL init performance.
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel)
+                    if (data_parallel and not current_platform.is_cuda_alike())
                     else contextlib.nullcontext()
                 ):
                     proc.start()
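
This is the launcher-side half of the change: CUDA-like platforms no longer get a per-process device-control environment variable around proc.start(); each worker picks its own GPU index instead (see gpu_worker.py below). A self-contained sketch of the launch pattern, where temporary_env and the variable name are illustrative stand-ins for set_device_control_env_var:

import contextlib
import os

@contextlib.contextmanager
def temporary_env(name: str, value: str):
    # Illustrative stand-in for set_device_control_env_var: set an env var
    # around proc.start() so the child process inherits it, then restore it.
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

def launch(proc, data_parallel: bool, is_cuda_alike: bool, local_dp_rank: int):
    # After this change, only non-CUDA platforms still mask devices per DP process.
    ctx = (
        temporary_env("HYPOTHETICAL_DEVICE_CONTROL_VAR", str(local_dp_rank))
        if data_parallel and not is_cuda_alike
        else contextlib.nullcontext()
    )
    with ctx:
        proc.start()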

vllm/v1/worker/dp_utils.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,6 @@
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -20,7 +19,8 @@
 
 
 def _get_device_and_group(parallel_config: ParallelConfig):
-    device = current_platform.device_type
+    # Use the actual device assigned to the DP group, not just the device type
+    device = get_dp_group().device
     group = get_dp_group().device_group
 
     # Transfering this tensor from GPU to CPU will introduce a GPU sync
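
The subtlety here is that a bare device-type string such as "cuda" resolves to whatever the process's current device happens to be, whereas the DP group records the exact torch.device it was initialised with. A short illustration; the index 3 is just an example standing in for get_dp_group().device:

import torch

# A device-type string lands the tensor on the process's *current* CUDA device,
# i.e. cuda:0 unless torch.cuda.set_device() was called first:
t_by_type = torch.zeros(4, device="cuda")

# An explicit device, as stored on the DP group, pins it to the intended GPU:
dp_device = torch.device("cuda", 3)   # example stand-in for get_dp_group().device
t_by_group = torch.zeros(4, device=dp_device)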

vllm/v1/worker/gpu_worker.py

Lines changed: 21 additions & 0 deletions
@@ -169,6 +169,27 @@ def init_device(self):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            if (
+                self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.data_parallel_backend != "ray"
+            ):
+                # Use local DP rank if available, otherwise use global DP rank.
+                dp_local_rank = self.parallel_config.data_parallel_rank_local
+                if dp_local_rank is None:
+                    dp_local_rank = self.parallel_config.data_parallel_rank
+
+                tp_pp_world_size = (
+                    self.parallel_config.pipeline_parallel_size
+                    * self.parallel_config.tensor_parallel_size
+                )
+
+                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+                self.local_rank += dp_local_rank * tp_pp_world_size
+                assert self.local_rank <= torch.cuda.device_count(), (
+                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
+                )
+
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
 
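
A worked example of the index arithmetic in this hunk: with TP=2, PP=1 and two local DP ranks on one node, DP rank 0's workers keep GPUs 0 and 1, while DP rank 1's workers are shifted to GPUs 2 and 3. A small sketch under those assumptions; the helper name is hypothetical:

def dp_adjusted_local_rank(tp_local_rank: int, dp_local_rank: int,
                           tensor_parallel_size: int, pipeline_parallel_size: int) -> int:
    # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK, as in the comment above:
    # each DP rank claims a contiguous slice of the node's GPUs now that
    # CUDA_VISIBLE_DEVICES no longer masks devices per process.
    tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size
    return dp_local_rank * tp_pp_world_size + tp_local_rank

# TP=2, PP=1, two local DP ranks -> GPUs 0-1 for DP rank 0, GPUs 2-3 for DP rank 1.
assert dp_adjusted_local_rank(0, 0, 2, 1) == 0
assert dp_adjusted_local_rank(1, 0, 2, 1) == 1
assert dp_adjusted_local_rank(0, 1, 2, 1) == 2
assert dp_adjusted_local_rank(1, 1, 2, 1) == 3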
