Commit db1764e

[Platform] allow platform to init dp group (#22243)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

1 parent: 7f83b4e

5 files changed: +15 -83 lines

vllm/config/parallel.py

Lines changed: 1 addition & 1 deletion

@@ -334,7 +334,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                 self.get_next_dp_init_port(),
                 self.data_parallel_rank,
                 self.data_parallel_size,
-                backend="gloo",
+                backend=current_platform.dist_backend,
             )
         except DistNetworkError as e:
             # We only want to retry when the root cause is EADDRINUSE.
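
After this change, ParallelConfig.stateless_init_dp_group() takes the data-parallel backend from the active platform instead of hardcoding "gloo". A minimal sketch of how an out-of-tree platform might advertise its backend; the class name and the "hccl" value are illustrative assumptions, not code from this commit:

# Hypothetical out-of-tree platform. The class name and the "hccl"
# value are assumptions for illustration only.
from vllm.platforms.interface import Platform


class MyNPUPlatform(Platform):
    # Read by ParallelConfig.stateless_init_dp_group() through
    # current_platform.dist_backend after this commit.
    dist_backend: str = "hccl"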

vllm/distributed/utils.py

Lines changed: 13 additions & 13 deletions

@@ -415,7 +415,6 @@ def create(
 
 
 def init_gloo_process_group(
-    backend: Backend,
     prefix_store: PrefixStore,
     group_rank: int,
     group_size: int,
@@ -432,7 +431,7 @@ def init_gloo_process_group(
             group_size,
         )
     else:
-        options = ProcessGroup.Options(backend=backend)
+        options = ProcessGroup.Options(backend="gloo")
         pg = ProcessGroup(
             prefix_store,
             group_rank,
@@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
     # Use a PrefixStore to avoid accidental overrides of keys used by
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
+    try:
+        from vllm.platforms import current_platform
 
-    if backend == "gloo":
-        return init_gloo_process_group(
+        return current_platform.stateless_init_device_torch_dist_pg(
             backend=backend,
             prefix_store=prefix_store,
             group_rank=group_rank,
             group_size=group_size,
             timeout=timeout,
         )
-    from vllm.platforms import current_platform
-
-    return current_platform.stateless_init_device_torch_dist_pg(
-        backend=backend,
-        prefix_store=prefix_store,
-        group_rank=group_rank,
-        group_size=group_size,
-        timeout=timeout,
-    )
+    except NotImplementedError:
+        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
+        # will raise a NotImplementedError. In this case, we fall back to gloo.
+        return init_gloo_process_group(
+            prefix_store=prefix_store,
+            group_rank=group_rank,
+            group_size=group_size,
+            timeout=timeout,
+        )
 
 
 def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
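
The dispatch above tries the platform hook first and treats NotImplementedError, now the default in vllm/platforms/interface.py (below), as "fall back to gloo". Judging from this diff alone, CUDA and ROCm, whose NCCL overrides are deleted below, reach the gloo fallback on this path, matching the previously hardcoded behavior. A condensed, runnable sketch of the pattern; _PlatformStub and _init_pg are hypothetical stand-ins, not vLLM names:

# Hypothetical stand-ins (_PlatformStub, _init_pg) mimicking the
# try/except dispatch added to
# stateless_init_torch_distributed_process_group above.
class _PlatformStub:
    def stateless_init_device_torch_dist_pg(self, **kwargs):
        # A platform that does not override the hook raises
        # NotImplementedError, like the new default in interface.py.
        raise NotImplementedError


def _init_pg(platform, **kwargs):
    try:
        return platform.stateless_init_device_torch_dist_pg(**kwargs)
    except NotImplementedError:
        # Stands in for init_gloo_process_group(**kwargs).
        return "gloo group"


print(_init_pg(_PlatformStub()))  # -> gloo group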

vllm/platforms/cuda.py

Lines changed: 0 additions & 34 deletions

@@ -6,13 +6,10 @@
 
 import os
 from collections.abc import Callable
-from datetime import timedelta
 from functools import cache, wraps
 from typing import TYPE_CHECKING, TypeVar
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
@@ -455,37 +452,6 @@ def opaque_attention_op(cls) -> bool:
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()

vllm/platforms/interface.py

Lines changed: 1 addition & 1 deletion

@@ -551,7 +551,7 @@ def stateless_init_device_torch_dist_pg(
         """
         Init platform-specific torch distributed process group.
         """
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+        raise NotImplementedError
 
     @classmethod
     def is_kv_cache_dtype_supported(
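
A platform opts in by overriding this hook; the NCCL bodies removed from cuda.py (above) and rocm.py (below) show the shape such an implementation takes. A skeletal sketch with the same signature; MyPlatform is a hypothetical name and the device-specific group construction is elided:

from datetime import timedelta

from torch.distributed import PrefixStore, ProcessGroup

from vllm.platforms.interface import Platform


class MyPlatform(Platform):
    # Hypothetical platform; only the hook's signature is shown.
    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        # Build a ProcessGroup and register a device-specific backend,
        # as the removed NCCL code in cuda.py does; leaving the base
        # class to raise NotImplementedError selects the gloo fallback
        # in vllm/distributed/utils.py instead.
        ...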

vllm/platforms/rocm.py

Lines changed: 0 additions & 34 deletions

@@ -2,13 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -476,37 +473,6 @@ def is_navi(cls) -> bool:
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
