Commit 56ed760

Revert "[Misc] Remove use of CUDA_VISIBLE_DEVICES for device selectio… (#27502)
1 parent 29c9cb8 commit 56ed760

4 files changed: +7, -35 lines changed


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 4 additions & 8 deletions
@@ -991,14 +991,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
-        self.device_id = self.tp_rank
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
 
             for cache in cache_list:
                 base_addr = cache.data_ptr()
-                if not self.use_host_buffer and current_platform.is_cuda_alike():
-                    self.device_id = cache.device.index
                 if base_addr in seen_base_addresses:
                     continue
 
@@ -1026,7 +1023,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
                 )
 
         logger.debug(
@@ -1073,7 +1070,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.device_id))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))
 
             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1084,13 +1081,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.device_id))
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
         logger.debug(
-            "Created %s blocks for src engine %s and rank %s on device id %s",
+            "Created %s blocks for src engine %s and rank %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
-            self.device_id,
         )
 
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
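
The net effect of these hunks is that every NIXL registration tuple carries self.tp_rank as its device id again, instead of the index reported by the cache tensor. A minimal sketch of the two choices the diff switches between (not vLLM code; the helper name and flag below are made up for illustration):

import torch

def pick_device_id(cache: torch.Tensor, tp_rank: int, *, per_cache_index: bool) -> int:
    # per_cache_index=True mirrors the reverted behaviour: ask the tensor which
    # GPU it lives on. per_cache_index=False mirrors this commit: reuse tp_rank,
    # which lines up with the single visible device once CUDA_VISIBLE_DEVICES
    # pins one GPU per worker.
    if per_cache_index and cache.is_cuda:
        return cache.device.index
    return tp_rank

if torch.cuda.is_available():
    kv_cache = torch.empty(16, device="cuda")
    print(pick_device_id(kv_cache, tp_rank=3, per_cache_index=True))   # index torch reports, e.g. 0
    print(pick_device_id(kv_cache, tp_rank=3, per_cache_index=False))  # 3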

vllm/v1/engine/utils.py

Lines changed: 1 addition & 4 deletions
@@ -134,12 +134,9 @@ def __init__(
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
-                # Adjust device control in DP for non-CUDA platforms
-                # For CUDA platforms, setting same device id for different DP
-                # processes affects NCCL init performance.
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel and not current_platform.is_cuda_alike())
+                    if (data_parallel)
                     else contextlib.nullcontext()
                 ):
                     proc.start()
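
With the platform check gone, set_device_control_env_var wraps the worker launch for every data-parallel run, CUDA included, so the device-control variable is only scoped around proc.start(). A self-contained sketch of that conditional-context-manager pattern, using a toy stand-in for the real helper (the demo function below is not vLLM's implementation):

import contextlib
import os

@contextlib.contextmanager
def set_device_control_env_var_demo(var: str, value: str):
    # Toy stand-in: temporarily set a device-control variable such as
    # CUDA_VISIBLE_DEVICES for the duration of the child-process launch.
    old = os.environ.get(var)
    os.environ[var] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = old

data_parallel = True
local_dp_rank = 1

with (
    set_device_control_env_var_demo("CUDA_VISIBLE_DEVICES", str(local_dp_rank))
    if data_parallel
    else contextlib.nullcontext()
):
    print(os.environ.get("CUDA_VISIBLE_DEVICES"))  # "1" while the worker starts
print(os.environ.get("CUDA_VISIBLE_DEVICES"))      # restored afterwards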

vllm/v1/worker/dp_utils.py

Lines changed: 2 additions & 2 deletions
@@ -8,6 +8,7 @@
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -19,8 +20,7 @@
 
 
 def _get_device_and_group(parallel_config: ParallelConfig):
-    # Use the actual device assigned to the DP group, not just the device type
-    device = get_dp_group().device
+    device = current_platform.device_type
     group = get_dp_group().device_group
 
     # Transfering this tensor from GPU to CPU will introduce a GPU sync
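
After the revert, _get_device_and_group hands back the bare platform device type (e.g. "cuda") instead of the concrete device object owned by the DP group, so the coordination tensor lands on whatever device the process has already selected. An illustrative contrast of the two allocation styles (assumes a CUDA build of PyTorch; nothing here is vLLM code):

import torch

if torch.cuda.is_available():
    torch.cuda.set_device(0)
    # Bare device type: PyTorch resolves it to the process's current device,
    # which is what the reverted selection scheme relies on.
    by_type = torch.zeros(1, device="cuda")
    # Concrete device object with an explicit index, as the removed
    # get_dp_group().device path would have provided.
    by_index = torch.zeros(1, device=torch.device("cuda", 0))
    print(by_type.device, by_index.device)  # both cuda:0 here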

vllm/v1/worker/gpu_worker.py

Lines changed: 0 additions & 21 deletions
@@ -172,27 +172,6 @@ def init_device(self):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            if (
-                self.parallel_config.data_parallel_size > 1
-                and self.parallel_config.data_parallel_size_local > 0
-                and self.parallel_config.data_parallel_backend != "ray"
-            ):
-                # Use local DP rank if available, otherwise use global DP rank.
-                dp_local_rank = self.parallel_config.data_parallel_rank_local
-                if dp_local_rank is None:
-                    dp_local_rank = self.parallel_config.data_parallel_rank
-
-                tp_pp_world_size = (
-                    self.parallel_config.pipeline_parallel_size
-                    * self.parallel_config.tensor_parallel_size
-                )
-
-                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
-                self.local_rank += dp_local_rank * tp_pp_world_size
-                assert self.local_rank <= torch.cuda.device_count(), (
-                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
-                )
-
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
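
The deleted block remapped each worker's local rank across data-parallel replicas because, without CUDA_VISIBLE_DEVICES narrowing visibility, every GPU on the node was visible to every worker. A worked example of that arithmetic with illustrative numbers (the values below are made up, not taken from the commit):

# Offline DP launch on one node, all GPUs visible: TP=2, PP=1, two DP replicas.
pipeline_parallel_size = 1
tensor_parallel_size = 2
dp_local_rank = 1   # this worker belongs to the second DP replica on the node
tp_local_rank = 0   # first TP worker inside that replica

# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK, as in the removed comment.
tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size
adjusted_local_rank = tp_local_rank + dp_local_rank * tp_pp_world_size
print(adjusted_local_rank)  # 2 -> the worker would bind cuda:2

# After the revert, the launcher narrows CUDA_VISIBLE_DEVICES per DP replica
# instead, so the worker keeps its unadjusted local rank (cuda:0 or cuda:1).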
