@@ -149,6 +149,51 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+    def _forward_ms_mlp(self, x):
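+        # Multistream variant of forward(): matmuls run on the default
+        # stream, while down_proj's communication is issued on the shared
+        # comm stream so it can overlap with the other microbatch's compute.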
+        current_ms_metadata = get_multistream_comm_context()
+        assert current_ms_metadata is not None
+        if self.is_dynamic_quant:
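+            # W8A8 dynamic-quant path: quantize activations per token, run
+            # int8 matmuls, with a fused dequant-SwiGLU-requant in between.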
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
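+            # hand the TP all-reduce to the comm stream: record on the
+            # default stream, wait on the comm stream, record completion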
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    x = tensor_model_parallel_all_reduce(x)
+                    current_ms_metadata.after_comm_event.record()
+            return x
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
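+        # unquantized path: down_proj (with its fused all-reduce) runs on
+        # the comm stream behind the same event pair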
+        current_ms_metadata.before_comm_event.record()
+        with torch.npu.stream(current_ms_metadata.comm_stream):
+            current_ms_metadata.before_comm_event.wait()
+            x, _ = self.down_proj(x)
+            current_ms_metadata.after_comm_event.record()
+
+        return x
+
 
 class CustomDeepseekV2MoE(nn.Module):
 
@@ -282,7 +327,7 @@ def _forward_ms_op_shared_expert(
         self,
         hidden_states: torch.Tensor,
     ):
-        shared_output = self.shared_experts(hidden_states)
+        shared_output = self.shared_experts._forward_ms_mlp(hidden_states)
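+        # route shared experts through the multistream MLP so their
+        # all-reduce overlaps with the next microbatch (see block 3 of
+        # _forward_ms_layer)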
         return shared_output
 
     def _forward_ms_op_gate(
@@ -293,7 +338,7 @@ def _forward_ms_op_gate(
         router_logits, _ = self.gate(hidden_states)
         return router_logits
 
-    def _forward_ms_op_tp_allreduce(
+    def _forward_ms_op_tp_allgather(
         self,
         hidden_states: torch.Tensor,
         shared_output: torch.Tensor,
@@ -303,13 +348,26 @@ def _forward_ms_op_tp_allreduce(
     ):
 
         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            #if num_tokens < self.tp_size:
-            #    final_hidden_states = final_hidden_states[:num_tokens]
-            if num_tokens > 0:
-                final_hidden_states = final_hidden_states[:-num_tokens]
+            current_ms_metadata = get_multistream_comm_context()
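+            # with no active multistream context, fall back to the plain
+            # single-stream all-gather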
+            if current_ms_metadata is None:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                #if num_tokens < self.tp_size:
+                #    final_hidden_states = final_hidden_states[:num_tokens]
+                if num_tokens > 0:
+                    final_hidden_states = final_hidden_states[:-num_tokens]
+            else:
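+                # run the all-gather on the comm stream so it overlaps
+                # with default-stream compute for the other microbatch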
+                current_ms_metadata.before_comm_event.record()
+                with torch.npu.stream(current_ms_metadata.comm_stream):
+                    current_ms_metadata.before_comm_event.wait()
+                    dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                    self.tp_group)
+                    final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                    #if num_tokens < self.tp_size:
+                    #    final_hidden_states = final_hidden_states[:num_tokens]
+                    if num_tokens > 0:
+                        final_hidden_states = final_hidden_states[:-num_tokens]
+
         else:
             final_hidden_states = hidden_states
 
@@ -650,25 +708,24 @@ def _forward_ms_layer(
 
             # input layernorm
             hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i])
-            # attention and tp allreducea
+            # attention and tp allreduce
             hidden_states[i], residual[i] = self._forward_ms_op_attn(positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i])
 
         ''' block 3: shared experts
         if there is an allreduce op in the shared experts, we can overlap it with the computation of the
         shared expert for the next microbatch or the moe gating
         '''
         for i in range(num_micro_batchs):
+            ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
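+            # block until this microbatch's attention all-reduce has
+            # finished before launching any shared-expert work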
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMP_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH],
             )
             with set_multistream_context(context, i):
                 # compute shared expert after finishing ATTN AR
-                ms_metadata.try_wait_event(layer_index, i, MSEventKey.ATTN_AR_FINISH)
                 hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm(hidden_states[i], residual[i])
 
-
             num_token, hidden_dim = hidden_states[i].shape
             hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
             #num_tokens.append(num_token)
@@ -740,10 +797,14 @@ def _forward_ms_layer(
                 before_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_COM_FINISH],
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_AFTER_COMM],
             )
-            with set_multistream_context(context, i):
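+            # same record/wait pattern as _forward_ms_mlp: the MoE
+            # all-reduce proceeds on the comm stream while the default
+            # stream continues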
+            context.before_comm_event.record()
+            with torch.npu.stream(ms_metadata.communicate_stream):
+                #with set_multistream_context(context, i):
+                context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (self.mlp.experts.tp_size > 1 or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
+                context.after_comm_event.record()
             # check here
             hidden_states[i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
@@ -752,7 +813,7 @@ def _forward_ms_layer(
                 after_comm_event=ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH],
             )
             with set_multistream_context(context, i):
-                hidden_states[i] = self.mlp._forward_ms_op_tp_allreduce(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
+                hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(hidden_states[i], shared_outputs[i], chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
             with torch.npu.stream(ms_metadata.communicate_stream):
                 # last
                 if isinstance(
@@ -764,6 +825,7 @@ def _forward_ms_layer(
                 # The scaling of DeepseekV2MOE output would be done in the forward
                 # of DeepseekV2MOE
                 hidden_states[i] *= 1. / self.routed_scaling_factor
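+                # record completion of the FFN communication only after
+                # the final scaling has been applied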
+                context.after_comm_event.record()
         return hidden_states, residual
     # should split ops in Decoder Layer
     def _forward_ms_op_input_layernorm(
@@ -951,9 +1013,9 @@ def can_run_ms(self):
             cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
             if torch.any(cu_tokens == 0).item():
                 return False
+        '''
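+        # the split must leave both microbatches non-empty, otherwise
+        # multistream overlap cannot run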
         [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens,
-            attn_metadata.attn_state, attn_metadata.num_decode_tokens)
-        '''
+                                                           attn_metadata.attn_state, attn_metadata.num_decode_tokens)
         if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens):
             return False
         # check whether the total tokens exceed the threshold