
Commit f79c6a1

tlrmchlsmth authored and eicherseiji committed
[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE (vllm-project#24134)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
1 parent 1d8a4ca commit f79c6a1

File tree: 4 files changed (+132, -59 lines)

  vllm/model_executor/layers/fused_moe/layer.py
  vllm/model_executor/models/deepseek_eagle.py
  vllm/model_executor/models/deepseek_mtp.py
  vllm/model_executor/models/deepseek_v2.py


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 17 additions & 4 deletions
@@ -35,7 +35,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)

 if current_platform.is_cuda_alike():
@@ -786,6 +786,7 @@ def __init__(
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        is_sequence_parallel=False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -797,6 +798,10 @@ def __init__(
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)

+        self.is_sequence_parallel = is_sequence_parallel
+        if self.is_sequence_parallel:
+            self.sp_size = tp_size_
+
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1699,14 +1704,22 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

         ctx = get_forward_context()
         # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

+        # If the input to the MoE is sequence parallel then divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+        if self.is_sequence_parallel:
+            max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
+                                                 self.sp_size)
+
         num_tokens = full_hidden_states.size(0)
         for chunk_idx, chunk_start_ in enumerate(
-                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
+                range(0, max_tokens_across_dispatchers,
+                      moe_dp_chunk_size_per_rank)):
            chunk_start = chunk_start_
            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
-                           max_tokens_across_dp)
+                           max_tokens_across_dispatchers)
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)
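
To see the effect of routing the chunk loop through cdiv, here is a small standalone sketch (not part of the diff; the token count, chunk size, and sp_size below are illustrative assumptions): with sequence-parallel input each dispatcher holds roughly 1/sp_size of the tokens, so the chunked forward runs far fewer iterations.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the semantics of vllm.utils.cdiv.
    return -(a // -b)

def num_chunk_iterations(max_tokens: int, chunk_size: int) -> int:
    # Iterations of the range(0, max_tokens, chunk_size) loop above.
    return cdiv(max_tokens, chunk_size)

max_tokens_across_dp = 8192        # busiest DP rank (assumed)
moe_dp_chunk_size_per_rank = 1024  # assumed chunk size
sp_size = 8                        # TP/SP degree (assumed)

# Before the fix: the loop always spans the full replicated token count.
print(num_chunk_iterations(max_tokens_across_dp, moe_dp_chunk_size_per_rank))   # 8

# After the fix: each dispatcher sees at most cdiv(8192, 8) = 1024 tokens,
# so a single chunk suffices.
max_tokens_across_dispatchers = cdiv(max_tokens_across_dp, sp_size)
print(num_chunk_iterations(max_tokens_across_dispatchers,
                           moe_dp_chunk_size_per_rank))                          # 1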

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 1 addition & 6 deletions
@@ -37,8 +37,6 @@ def __init__(
         super().__init__()
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.vocab_size = self.config.vocab_size

@@ -51,11 +49,8 @@ def __init__(

         self.layers = nn.ModuleList([
             DeepseekV2DecoderLayer(
-                self.config,
+                vllm_config,
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
             ) for i in range(self.config.num_hidden_layers)
         ])

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 9 additions & 18 deletions
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig

-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -43,23 +43,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

 class DeepSeekMultiTokenPredictorLayer(nn.Module):

-    def __init__(
-        self,
-        config: PretrainedConfig,
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
         self.shared_head = SharedHead(config=config, quant_config=quant_config)
-        self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
-                                                cache_config, quant_config)
+        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)

     def forward(
         self,
@@ -95,13 +91,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # to map the exact layer index from weights
         self.layers = torch.nn.ModuleDict({
             str(idx):
-            DeepSeekMultiTokenPredictorLayer(
-                config,
-                f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
-                quant_config=vllm_config.quant_config,
-            )
+            DeepSeekMultiTokenPredictorLayer(vllm_config,
+                                             f"{prefix}.layers.{idx}")
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })

vllm/model_executor/models/deepseek_v2.py

Lines changed: 105 additions & 31 deletions
@@ -32,12 +32,14 @@
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config

+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,7 +57,9 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv, direct_register_custom_op

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -72,19 +76,27 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        is_sequence_parallel=False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # If is_sequence_parallel, the input and output tensors are sharded
+        # across the ranks within the tp_group. In this case the weights are
+        # replicated and no collective ops are needed.
+        # Otherwise we use standard TP with an allreduce at the end.
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            disable_tp=is_sequence_parallel,
             prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            reduce_results=reduce_results,
+                                           disable_tp=is_sequence_parallel,
                                            prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
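
As a rough sanity check on the disable_tp choice (a back-of-the-envelope sketch; the sizes are assumptions, not taken from any model config): with standard TP each rank multiplies all replicated tokens by a 1/tp_size shard of the weights, while with the sequence-parallel layout each rank multiplies its 1/tp_size slice of tokens by the full, replicated weights. The per-rank matmul work is identical, but the sequence-parallel layout needs no all-reduce afterwards and hands only unique tokens to the expert dispatch.

# Multiply-accumulate count per rank for one hidden_size x intermediate_size linear.
def per_rank_macs(num_tokens: int, hidden_size: int, intermediate_size: int) -> int:
    return num_tokens * hidden_size * intermediate_size

T, H, I, tp_size = 4096, 7168, 2048, 8   # illustrative sizes (assumptions)

tp_layout = per_rank_macs(T, H, I // tp_size)   # replicated tokens, sharded weights
sp_layout = per_rank_macs(T // tp_size, H, I)   # sharded tokens, replicated weights
assert tp_layout == sp_layout                   # same per-rank compute, no all-reduce in SP
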
@@ -98,17 +110,58 @@ def forward(self, x):
         return x


+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        x = nn.functional.pad(x, (0, 0, 0, pad_len))
+
+    chunk = x.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(x, 0, start, chunk)
+
+
+def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk",
+    op_func=sequence_parallel_chunk,
+    mutates_args=[],
+    fake_impl=sequence_parallel_chunk_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 class DeepseekV2MoE(nn.Module):

     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
         self.routed_scaling_factor = config.routed_scaling_factor

         self.ep_group = get_ep_group().device_group
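
The pad-and-narrow behaviour of sequence_parallel_chunk, and the shape contract of its fake impl, can be checked outside vLLM with a minimal sketch that takes tp_size/tp_rank as explicit arguments rather than reading them from the TP group; the helper name and the standalone form are assumptions for illustration only.

import torch
import torch.nn.functional as F

def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the token dim up to a multiple of tp_size, then take this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        # (0, 0, 0, pad) pads dim 0 (tokens) on the right for a 2-D tensor.
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 4)                      # 10 tokens, hidden size 4 (assumed)
chunks = [chunk_for_rank(x, 4, r) for r in range(4)]

# Every rank gets ceil(10 / 4) = 3 tokens, matching the fake impl's output shape.
assert all(c.shape == (3, 4) for c in chunks)

# Concatenating the per-rank chunks and trimming the padding recovers the input,
# which is what the all_gather + [:num_tokens] in DeepseekV2MoE.forward relies on.
assert torch.equal(torch.cat(chunks, dim=0)[:10], x)
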
@@ -117,6 +170,21 @@ def __init__(
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts

+        # The all_reduce at the end of attention (during o_proj) means that
+        # inputs are replicated across each rank of the tensor parallel group.
+        # If using expert-parallelism with DeepEP All2All ops, replicated
+        # tokens results in useless duplicate computation and communication.
+        #
+        # In this case, ensure the input to the experts is sequence parallel
+        # to avoid the excess work.
+        #
+        # Not needed for pplx-kernels as it can handle duplicate input tokens.
+        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
+                                     in ("deepep_high_throughput",
+                                         "deepep_low_latency")
+                                     and parallel_config.enable_expert_parallel
+                                     and self.tp_size > 1)
+
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
@@ -133,9 +201,8 @@ def __init__(
             self.gate.e_score_correction_bias = None

         # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        eplb_config = parallel_config.eplb_config
+        self.enable_eplb = parallel_config.enable_eplb

         self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
@@ -166,7 +233,9 @@ def __init__(
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
             self.shared_experts = None
         else:
             intermediate_size = (config.moe_intermediate_size *
@@ -177,6 +246,7 @@ def __init__(
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                is_sequence_parallel=self.is_sequence_parallel,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -199,11 +269,22 @@ def __init__(
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Chunk the hidden states so they aren't replicated across TP ranks.
+        # This avoids duplicate computation in self.experts.
+        # TODO: We can replace the all_reduce at the end of attn with a
+        # reduce_scatter instead of chunking here.
+        if self.is_sequence_parallel:
+            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
+                hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

@@ -228,7 +309,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             assert shared_output is not None
             final_hidden_states += shared_output

-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
                     final_hidden_states))
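
The new sequence-parallel branch in forward depends on the MoE being token-wise: each rank computes its own token slice, then tensor_model_parallel_all_gather(..., 0) concatenates the slices and the [:num_tokens] trim drops the padding. A single-process sketch (the stand-in expert function and the explicit loop over ranks are illustrative assumptions) shows that this round trip matches running on the full, unchunked input:

import torch
import torch.nn.functional as F

def token_wise_experts(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the expert computation: any per-token function works here.
    return torch.tanh(x) * 2.0

def sequence_parallel_forward(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    num_tokens = x.size(0)
    remainder = num_tokens % tp_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    # Each "rank" processes only its slice of the tokens.
    outs = [token_wise_experts(torch.narrow(x, 0, r * chunk, chunk))
            for r in range(tp_size)]
    # all_gather along dim 0, then drop the padded rows.
    return torch.cat(outs, dim=0)[:num_tokens]

x = torch.randn(10, 4)
torch.testing.assert_close(sequence_parallel_forward(x, tp_size=4),
                           token_wise_experts(x))
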
@@ -532,16 +617,15 @@ def forward(

 class DeepseekV2DecoderLayer(nn.Module):

-    def __init__(
-        self,
-        config: Union[DeepseekV2Config, DeepseekV3Config],
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -578,9 +662,9 @@ def __init__(
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = DeepseekV2MLP(
@@ -650,10 +734,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        enable_eplb = vllm_config.parallel_config.enable_eplb
         self.config = config

         self.vocab_size = config.vocab_size
@@ -669,14 +750,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: DeepseekV2DecoderLayer(
-                config,
-                prefix,
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                enable_eplb=enable_eplb,
-            ),
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
             prefix=f"{prefix}.layers")

         if get_pp_group().is_last_rank:
