 from vllm_ascend.multistream.context import (set_multistream_context, get_multistream_layer_context,
                                              advance_step_multistream_layer_context, get_multistream_comm_context)
 from vllm_ascend.multistream.layers import (MultiStreamPreTransformerLayer, MultiStreamPostTransformerLayer)
-from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata
+from vllm_ascend.multistream.metadata import make_multistream_metadata_ds, MultiStreamStepMetadata, MultiStreamConfig
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
+VLLM_ENABLE_MS: bool = envs_ascend.VLLM_ENABLE_MS
 
 
 class CustomDeepseekV2MLP(nn.Module):
@@ -305,8 +306,10 @@ def _forward_ms_op_tp_allreduce(
             dist.all_gather(list(chunk_hidden_states), hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
+            #if num_tokens < self.tp_size:
+            #    final_hidden_states = final_hidden_states[:num_tokens]
+            if num_tokens > 0:
+                final_hidden_states = final_hidden_states[:-num_tokens]
         else:
             final_hidden_states = hidden_states
 
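Note on the slicing change above: `num_tokens` now carries the pad count recorded in `_forward_ms_layer` (see the `num_tokens.append(padded_num_tokens)` change below), not the original row count, so `[:-num_tokens]` strips exactly the padded rows; the `> 0` guard is needed because `tensor[:-0]` returns an empty tensor. A minimal, self-contained sketch of that trim step (plain PyTorch, illustrative names, no distributed setup):

import torch

def trim_padding(gathered: torch.Tensor, pad_rows: int) -> torch.Tensor:
    # pad_rows is the number of zero rows appended before the TP dispatch.
    # gathered[:-0] would be empty, so only slice when padding was actually added.
    if pad_rows > 0:
        return gathered[:-pad_rows]
    return gathered

x = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 5 real tokens
padded = torch.nn.functional.pad(x, (0, 0, 0, 3))        # pad to 8 rows
assert torch.equal(trim_padding(padded, 3), x)           # padded rows removed
assert torch.equal(trim_padding(x, 0), x)                # no-op when unpadded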
@@ -641,6 +644,10 @@ def _forward_ms_layer(
                 )
 
             with set_multistream_context(context, i):
+                context = get_forward_context()
+                layer_index, ms_metadata, attn_metadata = get_multistream_layer_context()
+                context.attn_metadata = attn_metadata[i]
+
                 # input layernorm
                 hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm(hidden_states[i], residual[i])
                 # attention and tp allreduce
@@ -664,7 +671,7 @@ def _forward_ms_layer(
 
             num_token, hidden_dim = hidden_states[i].shape
             hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
-            num_tokens.append(num_token)
+            # num_tokens.append(num_token)
             hidden_dims.append(hidden_dim)
             if self.mlp.n_shared_experts is not None:
                 # TODO: we can move shared expert computation into next block if reduce results is false
@@ -686,13 +693,20 @@ def _forward_ms_layer(
                 enable_force_load_balance = False
 
             if self.mlp.tp_size > 1:
-                if num_tokens[i] < self.mlp.tp_size:
-                    target_size = self.mlp.tp_size
-                    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
-                                                    dtype=hidden_states[i].dtype,
-                                                    device=hidden_states[i].device)
-                    new_hidden_states[:num_tokens[i]] = hidden_states[i]
-                    hidden_states[i] = new_hidden_states
+                #if num_tokens[i] < self.mlp.tp_size:
+                #    target_size = self.mlp.tp_size
+                #    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
+                #                                    dtype=hidden_states[i].dtype,
+                #                                    device=hidden_states[i].device)
+                #    new_hidden_states[:num_tokens[i]] = hidden_states[i]
+                #    hidden_states[i] = new_hidden_states
+                num_token, _ = hidden_states[i].shape
+                padded_num_tokens = (self.mlp.tp_size -
+                                     num_token % self.mlp.tp_size) % self.mlp.tp_size
+                if padded_num_tokens > 0:
+                    hidden_states[i] = nn.functional.pad(hidden_states[i],
+                                                         (0, 0, 0, padded_num_tokens))
+                num_tokens.append(padded_num_tokens)
                 chunk_hidden_state = torch.tensor_split(hidden_states[i],
                                                         self.mlp.tp_size,
                                                         dim=0)
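Note: the new code pads each micro-batch up to the next multiple of `self.mlp.tp_size` rather than only handling the `num_tokens[i] < tp_size` case, so `torch.tensor_split` always yields equal-sized chunks and the pad count (not the token count) is what gets appended to `num_tokens`. A small stand-alone sketch of the padding arithmetic, with made-up sizes:

import torch
import torch.nn as nn

def pad_to_tp_multiple(x: torch.Tensor, tp_size: int):
    """Pad rows so x.shape[0] divides evenly by tp_size; return (padded, pad_rows)."""
    num_token = x.shape[0]
    pad_rows = (tp_size - num_token % tp_size) % tp_size  # 0 when already divisible
    if pad_rows > 0:
        # pad spec (left, right, top, bottom): nothing on the hidden dim,
        # pad_rows zero rows appended at the bottom.
        x = nn.functional.pad(x, (0, 0, 0, pad_rows))
    return x, pad_rows

x = torch.randn(5, 16)                       # 5 tokens, hidden size 16
padded, pad_rows = pad_to_tp_multiple(x, 4)  # pad_rows == 3, padded is 8 x 16
chunks = torch.tensor_split(padded, 4, dim=0)
assert all(c.shape == (2, 16) for c in chunks)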
@@ -713,7 +727,7 @@ def _forward_ms_layer(
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
 
-            hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(hidden_states[i], router_logits[i], is_prefill, real_top_k, enable_force_load_balance)
+            hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(local_hidden_states, router_logits[i], is_prefill, real_top_k, enable_force_load_balance)
 
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
@@ -847,7 +861,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 ["hidden_states", "residual"], config.hidden_size))
 
         # tbo related members
-        self.multistream_config = vllm_config.model_config.multistream_config
+        if VLLM_ENABLE_MS:
+            self.multistream_config = MultiStreamConfig()
+        else:
+            self.multistream_config = None
         self.use_mla = model_config.use_mla
         self.multistream_metadata = make_multistream_metadata_ds(
             start_layer=self.start_layer + self.first_k_dense_replace,
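Note: the constructor now derives the multistream config from the `VLLM_ENABLE_MS` environment flag instead of `vllm_config.model_config.multistream_config`, falling back to `None` when the flag is unset. A rough sketch of that gating pattern, assuming the flag is surfaced as an ordinary boolean environment variable (helper names here are illustrative, not part of the diff):

import os

# Hypothetical stand-in for envs_ascend.VLLM_ENABLE_MS.
VLLM_ENABLE_MS: bool = os.getenv("VLLM_ENABLE_MS", "0").lower() in ("1", "true")

def build_multistream_config(config_cls):
    # Default-construct the config when the flag is set, otherwise return None;
    # callers can check for None before building any multistream metadata.
    return config_cls() if VLLM_ENABLE_MS else None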
@@ -929,13 +946,14 @@ def can_run_ms(self):
             return False
         num_microbatchs = self.multistream_config.num_micro_batches
         # check whether there is a dp rank that does not use dual batch
-        if dp_metadata is not None:
+        ''' if dp_metadata is not None:
             for i in range(num_microbatchs):
                 cu_tokens = dp_metadata.cu_dbo_tokens_across_dp_cpu[i]
                 if torch.any(cu_tokens == 0).item():
                     return False
         [token_index, seq_index] = compute_split_seq_index(attn_metadata.query_lens,
-                                                           attn_metadata.attn_state, attn_metadata.num_decode_tokens)
+            attn_metadata.attn_state, attn_metadata.num_decode_tokens)
+        '''
         if token_index == 0 or seq_index == 0 or seq_index == len(attn_metadata.query_lens):
             return False
         # check whether the total tokens exceed the threshold