
Commit acc53c7

wenscarl, tlrmchlsmth, and mgoin authored and committed
Enable Allgather/ReduceScatter backend for NaiveAllToAll (vllm-project#23964)
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent cb481b4 commit acc53c7

File tree

3 files changed: +55 additions, -5 deletions

vllm/distributed/device_communicators/all2all.py

Lines changed: 39 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 import torch.distributed as dist
 
+from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -69,6 +70,44 @@ def destroy(self):
         pass
 
 
+class AgRsAll2AllManager(All2AllManagerBase):
+    """
+    An implementation of all2all communication based on
+    all-gather (dispatch) and reduce-scatter (combine).
+    """
+
+    def __init__(self, cpu_group):
+        super().__init__(cpu_group)
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        """
+        Gather hidden_states and router_logits from all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states, router_logits = get_dp_group().all_gatherv(
+            [hidden_states, router_logits],
+            dim=0,
+            sizes=sizes,
+        )
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Reduce-scatter hidden_states across all dp ranks.
+        """
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
+                                                       dim=0,
+                                                       sizes=sizes)
+        return hidden_states
+
+    def destroy(self):
+        pass
+
+
 class PPLXAll2AllManager(All2AllManagerBase):
     """
     All2All communication based on PPLX kernels.
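For intuition, the dispatch/combine pattern used by AgRsAll2AllManager can be sketched with plain torch.distributed collectives. The sketch below is an illustration rather than vLLM's implementation: it assumes the default process group and equal chunk sizes on every DP rank, whereas the all_gatherv/reduce_scatterv calls above handle unequal per-rank sizes via the sizes argument.

# Illustrative sketch only, not vLLM code: all-gather for dispatch and
# reduce-scatter for combine, assuming equal chunk sizes per rank and the
# default torch.distributed process group.
import torch
import torch.distributed as dist


def dispatch(hidden_states: torch.Tensor) -> torch.Tensor:
    """Gather every rank's token chunk and concatenate along dim 0."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(gathered, hidden_states)
    return torch.cat(gathered, dim=0)


def combine(hidden_states: torch.Tensor) -> torch.Tensor:
    """Sum partial expert outputs across ranks, keep this rank's chunk."""
    world_size = dist.get_world_size()
    chunks = [c.contiguous() for c in hidden_states.chunk(world_size, dim=0)]
    output = torch.empty_like(chunks[0])
    dist.reduce_scatter(output, chunks, op=dist.ReduceOp.SUM)
    return output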

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 4 additions & 0 deletions
@@ -87,6 +87,10 @@ def __init__(self,
                 from .all2all import NaiveAll2AllManager
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                 logger.info("Using naive all2all manager.")
+            elif all2all_backend == "allgather_reducescatter":
+                from .all2all import AgRsAll2AllManager
+                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                logger.info("Using AllGather-ReduceScatter all2all manager.")
             elif all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager
                 self.all2all_manager = PPLXAll2AllManager(self.cpu_group)

vllm/envs.py

Lines changed: 12 additions & 5 deletions
@@ -149,8 +149,11 @@
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
-    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput",
-                                  "deepep_low_latency"] = "naive"
+    VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
+                                  "deepep_high_throughput",
+                                  "deepep_low_latency",
+                                  "allgather_reducescatter"] = \
+        "allgather_reducescatter"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
@@ -1124,14 +1127,18 @@ def get_vllm_port() -> Optional[int]:
 
     # all2all backend for vllm's expert parallel communication
     # Available options:
-    # - "naive": naive all2all implementation using all-reduce
+    # - "naive": naive all2all implementation using broadcasts
+    # - "allgather_reducescatter": all2all implementation based on allgather and
+    #   reducescatter
     # - "pplx": use pplx kernels
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
     "VLLM_ALL2ALL_BACKEND":
-    env_with_choices("VLLM_ALL2ALL_BACKEND", "naive",
-                     ["naive", "pplx",
-                      "deepep_high_throughput", "deepep_low_latency"]),
+    env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
+                     ["naive", "pplx",
+                      "deepep_high_throughput",
+                      "deepep_low_latency",
+                      "allgather_reducescatter"]),
 
     # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
     # Both require compute capability 10.0 or above.
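Since "allgather_reducescatter" is now the default, no configuration change is needed to pick up the new backend. A minimal usage sketch follows, assuming vllm.envs resolves the variable lazily at access time (exporting VLLM_ALL2ALL_BACKEND in the shell before launching vLLM has the same effect):

# Usage sketch: pin the all2all backend explicitly before vLLM reads it.
# "allgather_reducescatter" is now the default; "naive" restores the
# previous behavior.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "allgather_reducescatter"

import vllm.envs as envs

print(envs.VLLM_ALL2ALL_BACKEND)  # expected: allgather_reducescatter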
