@@ -32,12 +32,14 @@
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
 
+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,7 +57,9 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv, direct_register_custom_op
 
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -72,19 +76,27 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        is_sequence_parallel=False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # If is_sequence_parallel, the input and output tensors are sharded
+        # across the ranks within the tp_group. In this case the weights are
+        # replicated and no collective ops are needed.
+        # Otherwise we use standard TP with an allreduce at the end.
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            disable_tp=is_sequence_parallel,
             prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            reduce_results=reduce_results,
+                                           disable_tp=is_sequence_parallel,
                                            prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -98,17 +110,58 @@ def forward(self, x):
         return x
 
 
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        x = nn.functional.pad(x, (0, 0, 0, pad_len))
+
+    chunk = x.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(x, 0, start, chunk)
+
+
+def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk",
+    op_func=sequence_parallel_chunk,
+    mutates_args=[],
+    fake_impl=sequence_parallel_chunk_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
         self.routed_scaling_factor = config.routed_scaling_factor
 
         self.ep_group = get_ep_group().device_group
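
A minimal standalone sketch of what the sequence_parallel_chunk op above does to the token dimension, assuming a hypothetical tp_size and tp_rank in place of the real tensor-parallel process group:

import torch
from torch import nn

def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the token (first) dimension so it splits evenly across ranks.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    # Each rank keeps one contiguous slice of the padded token dimension.
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

# e.g. 10 tokens with tp_size=4 are padded to 12, so every rank gets 3 rows.
x = torch.randn(10, 8)
assert chunk_for_rank(x, tp_size=4, tp_rank=3).shape == (3, 8)
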
@@ -117,6 +170,21 @@ def __init__(
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
 
+        # The all_reduce at the end of attention (during o_proj) means that
+        # inputs are replicated across each rank of the tensor parallel group.
+        # If using expert-parallelism with DeepEP All2All ops, replicated
+        # tokens result in useless duplicate computation and communication.
+        #
+        # In this case, ensure the input to the experts is sequence parallel
+        # to avoid the excess work.
+        #
+        # Not needed for pplx-kernels as it can handle duplicate input tokens.
+        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
+                                     in ("deepep_high_throughput",
+                                         "deepep_low_latency")
+                                     and parallel_config.enable_expert_parallel
+                                     and self.tp_size > 1)
+
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
@@ -133,9 +201,8 @@ def __init__(
             self.gate.e_score_correction_bias = None
 
         # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        eplb_config = parallel_config.eplb_config
+        self.enable_eplb = parallel_config.enable_eplb
 
         self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
@@ -166,7 +233,9 @@ def __init__(
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
             self.shared_experts = None
         else:
             intermediate_size = (config.moe_intermediate_size *
@@ -177,6 +246,7 @@ def __init__(
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                is_sequence_parallel=self.is_sequence_parallel,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -199,11 +269,22 @@ def __init__(
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Chunk the hidden states so they aren't replicated across TP ranks.
+        # This avoids duplicate computation in self.experts.
+        # TODO: We can replace the all_reduce at the end of attn with a
+        # reduce_scatter instead of chunking here.
+        if self.is_sequence_parallel:
+            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
+                hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
@@ -228,7 +309,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             assert shared_output is not None
             final_hidden_states += shared_output
 
-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
                     final_hidden_states))
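
To make the sequence-parallel round trip in forward concrete, a worked shape example with hypothetical sizes (num_tokens=10, hidden_dim=H, tp_size=4, padded length 12):

# hidden_states before chunking                  : [10, H]  replicated on every rank
# after torch.ops.vllm.sequence_parallel_chunk   : [ 3, H]  one shard of the padded 12 rows
# after self.experts(...) (+ shared output)      : [ 3, H]  per-rank expert outputs
# after tensor_model_parallel_all_gather(x, 0)   : [12, H]  shards re-concatenated along dim 0
# after [:num_tokens]                            : [10, H]  padding rows dropped
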
@@ -532,16 +617,15 @@ def forward(
 
 class DeepseekV2DecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: Union[DeepseekV2Config, DeepseekV3Config],
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -578,9 +662,9 @@ def __init__(
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = DeepseekV2MLP(
@@ -650,10 +734,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        enable_eplb = vllm_config.parallel_config.enable_eplb
         self.config = config
 
         self.vocab_size = config.vocab_size
@@ -669,14 +750,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: DeepseekV2DecoderLayer(
-                config,
-                prefix,
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                enable_eplb=enable_eplb,
-            ),
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
             prefix=f"{prefix}.layers")
 
         if get_pp_group().is_last_rank: