Commit e94aabe

tlrmchlsmth and yewentao256 authored and committed
[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 1e5e5d7 commit e94aabe

23 files changed: +541 −376 lines

vllm/config/parallel.py
Lines changed: 18 additions & 0 deletions

@@ -279,6 +279,24 @@ def stateless_init_dp_group(self) -> ProcessGroup:
         assert last_exc is not None
         raise last_exc

+    # The all_reduce at the end of attention (during o_proj) means that
+    # inputs are replicated across each rank of the tensor parallel group.
+    # If using expert-parallelism with DeepEP All2All ops, replicated
+    # tokens results in useless duplicate computation and communication.
+    #
+    # In this case, ensure the input to the experts is sequence parallel
+    # to avoid the excess work.
+    #
+    # Not needed for pplx-kernels as it can handle duplicate input tokens.
+    @property
+    def use_sequence_parallel_moe(self) -> bool:
+        return (envs.VLLM_ALL2ALL_BACKEND
+                in ("allgather_reducescatter", "naive",
+                    "deepep_high_throughput", "deepep_low_latency")
+                and self.enable_expert_parallel
+                and self.tensor_parallel_size > 1
+                and self.data_parallel_size > 1)
+
     @staticmethod
     def has_unfinished_dp(dp_group: ProcessGroup,
                           has_unfinished: bool) -> bool:
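
For illustration only, here is a minimal, self-contained sketch of the predicate added above. The standalone function and the example configurations are assumptions for this note, not vLLM code; the real property lives on ParallelConfig and reads VLLM_ALL2ALL_BACKEND from vllm.envs.

# Hypothetical sketch mirroring use_sequence_parallel_moe without importing vLLM.
SP_MOE_BACKENDS = ("allgather_reducescatter", "naive",
                   "deepep_high_throughput", "deepep_low_latency")

def use_sequence_parallel_moe(all2all_backend: str,
                              enable_expert_parallel: bool,
                              tensor_parallel_size: int,
                              data_parallel_size: int) -> bool:
    # Sequence-parallel MoE input only pays off when attention output is
    # TP-replicated (TP > 1), experts are parallelized across ranks (EP on),
    # and more than one DP rank contributes tokens.
    return (all2all_backend in SP_MOE_BACKENDS
            and enable_expert_parallel
            and tensor_parallel_size > 1
            and data_parallel_size > 1)

print(use_sequence_parallel_moe("deepep_high_throughput", True, 2, 2))  # True
print(use_sequence_parallel_moe("pplx", True, 2, 2))  # False: pplx tolerates duplicates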

vllm/distributed/device_communicators/all2all.py
Lines changed: 82 additions & 40 deletions

@@ -6,7 +6,7 @@
 import torch.distributed as dist

 import vllm.envs as envs
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -34,41 +34,60 @@ def __init__(self, cpu_group):
         super().__init__(cpu_group)

     def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
+                        cu_tokens_across_sp_cpu: torch.Tensor,
+                        is_sequence_parallel: bool) -> torch.Tensor:
         assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                              device=x.device,
                              dtype=x.dtype)

-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        rank = self.rank if is_sequence_parallel else self.dp_rank
+        world_size = (self.world_size
+                      if is_sequence_parallel else self.dp_world_size)
+
+        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
+        end = cu_tokens_across_sp_cpu[rank]
         buffer[start:end, :].copy_(x)
-        for idx in range(self.dp_world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            self.dp_group.broadcast(buffer[start:end, :], idx)
+        for idx in range(world_size):
+            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
+            end = cu_tokens_across_sp_cpu[idx]
+            get_ep_group().broadcast(buffer[start:end, :], idx)

         return buffer

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
-            [hidden_states, router_logits],
-            dim=0,
-            sizes=sizes,
-        )
-
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        dp_metadata = get_forward_context().dp_metadata
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
+
+        ep_rank = self.rank if is_sequence_parallel else self.dp_rank
+
+        dp_metadata = get_forward_context().dp_metadata
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
+        end = cu_tokens_across_sp_cpu[ep_rank]
+
+        all_hidden_states = get_ep_group().all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
         return hidden_states

     def destroy(self):
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+        hidden_states, router_logits = dist_group.all_gatherv(
            [hidden_states, router_logits],
            dim=0,
            sizes=sizes,
        )
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        hidden_states = dist_group.reduce_scatterv(hidden_states,
+                                                   dim=0,
+                                                   sizes=sizes)
         return hidden_states

     def destroy(self):
@@ -148,11 +178,17 @@ def get_handle(self, kwargs):
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError

     def destroy(self):
@@ -184,11 +220,17 @@ def __init__(self, cpu_group):
     def get_handle(self, kwargs):
         raise NotImplementedError

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError

     def destroy(self):
@@ -395,4 +437,4 @@ def cleanup(self):
         self.workspace_tensor = None
         self.prepare_workspace_tensor = None
         self.mapping = None
-        self.initialized = False
+        self.initialized = False
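
Both naive_multicast and the new combine slice a gathered buffer by cumulative per-rank token counts. A standalone sketch of that bookkeeping, with made-up token counts (the helper below is illustrative, not part of the diff):

import torch

# e.g. 4 sequence-parallel ranks contributing 3, 5, 2, and 4 tokens
tokens_per_rank = torch.tensor([3, 5, 2, 4])
cu_tokens_across_sp_cpu = torch.cumsum(tokens_per_rank, dim=0)  # [3, 8, 10, 14]

def rank_slice(rank: int) -> tuple[int, int]:
    # Each rank owns the half-open row range [start, end) of the gathered buffer.
    start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
    end = int(cu_tokens_across_sp_cpu[rank])
    return start, end

print(rank_slice(2))                      # (8, 10): rank 2 owns rows 8..9
print(int(cu_tokens_across_sp_cpu[-1]))   # 14: total rows in the buffer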

vllm/distributed/device_communicators/base_device_communicator.py
Lines changed: 20 additions & 8 deletions

@@ -28,6 +28,8 @@ def get_or_create(self, kwargs, func):


 class All2AllManagerBase:
+    rank: int
+    world_size: int

     def __init__(self, cpu_group):
         self.cpu_group = cpu_group
@@ -40,6 +42,7 @@ def __init__(self, cpu_group):
         # all2all lives in ep group, which is merged from dp and tp group
         self.dp_group = get_dp_group()
         self.tp_group = get_tp_group()
+
         # no self.ep_group since self.ep_group is still in construction
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
@@ -60,17 +63,21 @@ def get_handle(self, kwargs):
         # and reuse it for the same config.
         raise NotImplementedError

+    def dispatch(self,
+                 hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor,
+                 is_sequence_parallel: bool = False):
+        raise NotImplementedError
+
     def set_num_sms(self, num_sms: int):
         pass

     def max_sms_used(self) -> Optional[int]:
         return None  # None means it could use the whole GPU

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        raise NotImplementedError
-
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False):
         raise NotImplementedError

     def destroy(self):
@@ -267,15 +274,20 @@ def prepare_communication_buffer_for_model(self,
                 module.quant_method.init_prepare_finalize(module)

     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch the hidden states and router logits to the appropriate device.
         This is a no-op in the base class.
         """
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         """
         Combine the hidden states and router logits from the appropriate device.
         This is a no-op in the base class.
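
Because the new is_sequence_parallel parameter defaults to False, existing call sites that pass only (hidden_states, router_logits) keep working. A small sketch of that backward compatibility, using an illustrative stand-in class rather than the real base communicator:

import torch

class BaseCommunicatorSketch:
    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_sequence_parallel: bool = False):
        # No-op base behaviour, as in the diff above.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        return hidden_states

comm = BaseCommunicatorSketch()
h, logits = torch.randn(4, 16), torch.randn(4, 8)
assert comm.dispatch(h, logits)[0] is h                  # old-style call
assert comm.combine(h, is_sequence_parallel=True) is h   # new-style call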

vllm/distributed/device_communicators/cuda_communicator.py
Lines changed: 12 additions & 10 deletions

@@ -39,10 +39,6 @@ def __init__(self,
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
         use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM

-        # ep does not use pynccl
-        use_pynccl = "ep" not in unique_name
-
-        self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
         self.use_torch_symm_mem = use_torch_symm_mem

@@ -57,7 +53,7 @@ def __init__(self,
                                       SymmMemCommunicator)

         self.pynccl_comm: Optional[PyNcclCommunicator] = None
-        if use_pynccl and self.world_size > 1:
+        if self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
@@ -308,14 +304,20 @@ def _all_gather_single(input_: torch.Tensor,
         return output_list

     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
         hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
+            hidden_states, router_logits, is_sequence_parallel)
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
+        hidden_states = self.all2all_manager.combine(hidden_states,
+                                                     is_sequence_parallel)
         return hidden_states
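
The notable behavioural change here is the pynccl guard: the communicator no longer skips PyNccl for groups whose name contains "ep"; it only checks the group size. A hedged sketch of just that guard, with a string standing in for the real PyNcclCommunicator:

from typing import Optional

class CommunicatorGuardSketch:
    def __init__(self, unique_name: str, world_size: int):
        self.pynccl_comm: Optional[str] = None
        # Old behaviour (removed): use_pynccl = "ep" not in unique_name
        if world_size > 1:
            self.pynccl_comm = f"pynccl for {unique_name}"  # stand-in object

assert CommunicatorGuardSketch("ep", 4).pynccl_comm is not None  # EP now gets pynccl
assert CommunicatorGuardSketch("tp", 1).pynccl_comm is None      # single rank: skipped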

vllm/distributed/device_communicators/xpu_communicator.py
Lines changed: 11 additions & 5 deletions

@@ -75,14 +75,20 @@ def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
         dist.broadcast(input_, src=src, group=self.device_group)

     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
         hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
+            hidden_states, router_logits, is_sequence_parallel)
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
+        hidden_states = self.all2all_manager.combine(hidden_states,
+                                                     is_sequence_parallel)
         return hidden_states

vllm/distributed/parallel_state.py
Lines changed: 12 additions & 5 deletions

@@ -871,17 +871,24 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
             model)

     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
-                                                     router_logits)
+                                                     router_logits,
+                                                     is_sequence_parallel)
         else:
             return hidden_states, router_logits

-    def combine(self, hidden_states) -> torch.Tensor:
+    def combine(self,
+                hidden_states,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         if self.device_communicator is not None:
-            return self.device_communicator.combine(hidden_states)
+            return self.device_communicator.combine(hidden_states,
+                                                    is_sequence_parallel)
         else:
             return hidden_states
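
To show how the threaded flag is meant to reach a MoE layer, here is a hedged, self-contained sketch of the call pattern. FakeGroup stands in for the EP GroupCoordinator; real code would obtain the group via get_ep_group() and derive the flag from parallel_config.use_sequence_parallel_moe.

import torch

class FakeGroup:
    def __init__(self, device_communicator=None):
        self.device_communicator = device_communicator

    def dispatch(self, hidden_states, router_logits,
                 is_sequence_parallel: bool = False):
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states,
                                                     router_logits,
                                                     is_sequence_parallel)
        return hidden_states, router_logits

    def combine(self, hidden_states, is_sequence_parallel: bool = False):
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states,
                                                    is_sequence_parallel)
        return hidden_states

use_sp = True        # would come from parallel_config.use_sequence_parallel_moe
group = FakeGroup()  # no communicator attached, so both calls are no-ops
h, logits = torch.randn(6, 32), torch.randn(6, 4)
h, logits = group.dispatch(h, logits, is_sequence_parallel=use_sp)
h = group.combine(h, is_sequence_parallel=use_sp)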
