@@ -35,9 +35,16 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
+# Temporarily disable yapf since it conflicts with isort.
+# yapf: disable
 from vllm.distributed import (get_dp_group, get_pp_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
-                              get_tp_group)
+                              get_tp_group, split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+# yapf: enable
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -132,6 +139,80 @@ def weight_loader(self, param: torch.nn.Parameter,
         shard.copy_(loaded_weight)


+class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+                output = tensor_model_parallel_reduce_scatter(output_parallel,
+                                                              dim=0)
+            else:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
+class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
+
+    def forward(
+        self,
+        input_,
+        is_prefill=True
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        output_parallel = self.quant_method.apply(self,
+                                                  input_parallel,
+                                                  bias=bias_)
+        if self.reduce_results and self.tp_size > 1:
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output = output_parallel
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
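
Note (not part of the patch): the only difference between the two subclasses above is the decode-time branch that swaps the all-reduce for a reduce-scatter along the token dimension. A minimal single-process sketch in plain torch, with simulated TP ranks and made-up sizes, of why that is safe as long as a later all-gather along dim 0 follows:

    import torch

    tp_size = 4
    num_tokens, hidden = 8, 16  # 8 % 4 == 0, so dim 0 splits evenly

    # Per-rank partial outputs of the row-parallel matmul.
    partials = [torch.randn(num_tokens, hidden) for _ in range(tp_size)]

    # all-reduce: every rank would hold the full summed activation.
    all_reduced = sum(partials)

    # reduce-scatter along dim 0: rank r keeps only its token chunk of the sum.
    chunks = [p.chunk(tp_size, dim=0) for p in partials]
    scattered = [sum(c[r] for c in chunks) for r in range(tp_size)]

    # A later all-gather along dim 0 restores the all-reduce result.
    assert torch.allclose(torch.cat(scattered, dim=0), all_reduced, atol=1e-5)
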
@@ -291,10 +372,10 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                hidden_states: torch.Tensor,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                replace_allreduce: bool = False) -> torch.Tensor:
         forward_context = get_forward_context()
         if attn_metadata is None:
             attn_metadata = forward_context.attn_metadata
@@ -323,7 +404,7 @@ def forward(
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
             gate=self.gate if self.enable_multistream_moe else None,
-        )
+            replace_allreduce=replace_allreduce)

         hidden_states = (
             experts_hidden_states[0] * self.routed_scaling_factor +
@@ -370,6 +451,14 @@ def __init__(
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        self.prefix = prefix
+        self.debug_layer_idx = int(self.prefix.split(".")[-2])
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                              self.q_lora_rank,
@@ -406,11 +495,23 @@ def __init__(
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
-                                        self.hidden_size,
-                                        bias=False,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.o_proj")
+        if (config.n_routed_experts is not None
+                and self.debug_layer_idx >= config.first_k_dense_replace
+                and self.debug_layer_idx % config.moe_layer_freq == 0 and
+                ascend_config.torchair_graph_config.enable_multistream_moe):
+            self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")
+        else:
+            self.o_proj = CustomDeepseekV2RowParallelLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
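
Note (not part of the patch): the branch above routes o_proj through the reduce-scatter variant only on MoE layers and only when multistream MoE is enabled. A small sketch of which layers that predicate selects; the variable names mirror the DeepSeek-V2 config fields used above, but the values are hypothetical:

    n_routed_experts = 64        # hypothetical config values
    first_k_dense_replace = 3
    moe_layer_freq = 1
    enable_multistream_moe = True

    for layer_idx in range(8):
        use_replace_allreduce = (n_routed_experts is not None
                                 and layer_idx >= first_k_dense_replace
                                 and layer_idx % moe_layer_freq == 0
                                 and enable_multistream_moe)
        print(layer_idx, use_replace_allreduce)
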
@@ -456,14 +557,6 @@ def __init__(
             o_proj=self.o_proj,
         )

-        self.prefix = prefix
-        self.debug_layer_idx = int(self.prefix.split(".")[-2])
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
-
     def forward(
             self,
             positions: torch.Tensor,
@@ -530,6 +623,10 @@ def __init__(
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
         self.layer_idx = layer_idx
+        self.layers = config.num_hidden_layers
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tp_group().rank_in_group
+        ascend_config = get_ascend_config()
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
             attn_cls = CustomDeepseekV2MLAAttention
@@ -561,6 +658,8 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
+                and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
         else:
             self.mlp = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -569,11 +668,13 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
+            self.mla_moe_communication = False
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.first_k_dense_replace = config.first_k_dense_replace

     def forward(
         self,
@@ -582,8 +683,13 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
+        if attn_metadata is not None and attn_metadata.num_decodes > 0:
+            mla_moe_communication = self.mla_moe_communication and replace_allreduce
+        else:
+            mla_moe_communication = False
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -595,6 +701,9 @@ def forward(
             # to save npu memory because they're no longer used.
             dispose_tensor(previous_hidden_states)
             dispose_tensor(previous_residual)
+        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)

         hidden_states = self.self_attn(
             positions=positions,
@@ -603,6 +712,13 @@ def forward(
             attn_metadata=attn_metadata,
         )

+        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
+                0]:
+            chunk_hidden_states = torch.tensor_split(residual,
+                                                     self.tp_size,
+                                                     dim=0)
+            residual = chunk_hidden_states[self.tp_rank]
+
         if hidden_states.dtype == torch.float16:
             # Fix FP16 overflow
             # We scale both hidden_states and residual before
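
Note (not part of the patch): after o_proj reduce-scatters the decode tokens, hidden_states carries only this rank's chunk while residual still carries the full batch, so the block above drops residual down to the matching chunk before the residual add. A standalone sketch with illustrative shapes and rank:

    import torch

    tp_size, tp_rank = 4, 1
    residual = torch.randn(8, 16)                  # full decode batch
    hidden_states = torch.randn(8 // tp_size, 16)  # this rank's chunk after o_proj

    if residual.shape[0] != hidden_states.shape[0]:
        residual = torch.tensor_split(residual, tp_size, dim=0)[tp_rank]
    assert residual.shape == hidden_states.shape
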
@@ -618,7 +734,9 @@ def forward(
                 hidden_states, residual)

         if isinstance(self.mlp, CustomDeepseekV2MoE):
-            hidden_states = self.mlp(hidden_states, attn_metadata)
+            hidden_states = self.mlp(hidden_states,
+                                     attn_metadata,
+                                     replace_allreduce=mla_moe_communication)
         else:
             hidden_states = self.mlp(hidden_states)

@@ -631,6 +749,10 @@ def forward(
             # The scaling of DeepseekV2MOE output would be done in the forward
             # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
+        if mla_moe_communication and self.layer_idx == self.layers - 1:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)
+            residual = tensor_model_parallel_all_gather(residual, dim=0)

         return hidden_states, residual

@@ -649,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.tp_size = get_tensor_model_parallel_world_size()

         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -701,13 +824,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, residual,
+                positions,
+                hidden_states,
+                residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata,
+                replace_allreduce=replace_allreduce)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
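
Note (not part of the patch): replace_allreduce is computed once per model forward. The scatter/gather path is only enabled when the flattened token count divides evenly across TP ranks, since reduce-scatter along dim 0 needs equal-sized chunks. A sketch with arbitrary sizes:

    import torch

    tp_size = 4
    for num_tokens in (7, 8):
        hidden_states = torch.randn(num_tokens, 16)
        replace_allreduce = hidden_states.shape[0] % tp_size == 0
        print(num_tokens,
              "reduce-scatter path" if replace_allreduce else "all-reduce path")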