 from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
                                             MultiStreamPreTransformerLayer)
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
+                                              MultiStreamMetadata,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -698,13 +699,12 @@ def _forward_ms_layer(
         shared_outputs = []
         router_logits = []
         chunk_hidden_states = []
-        ''' block 1 : attention
-        block 2 : attn tp communication, currently we switch to the comm stream
-        in tensor_model_parallel_all_reduce;
-        the attn computation of microbatch 1 can be overlapped with the moe
-        communication in the previous layer, and the attn computation of microbatch
-        2 can be overlapped with the attn communication of microbatch 1
-        '''
+
+        # block 1 : attention
+        # block 2 : attn tp communication
+        # the attn computation of microbatch 1 can be overlapped with the moe
+        # communication in the previous layer, and the attn computation of microbatch 2
+        # can be overlapped with the attn communication of microbatch 1
         for i in range(num_micro_batchs):
             # wait last layer moe finishing communication
             ms_metadata.try_wait_event(layer_index - 1, i,
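The overlap described in the comment above rests on a record/wait handshake between the default (compute) stream and a dedicated communication stream. Below is a minimal, hypothetical sketch of that handshake, shown with torch.cuda streams as a stand-in for the torch.npu streams used in this file (the Stream/Event APIs are analogous); none of the tensors or variable names come from this PR.

```python
import torch

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    before_comm = torch.cuda.Event()   # "compute for this micro-batch is done"
    after_comm = torch.cuda.Event()    # "communication for this micro-batch is done"

    x = torch.randn(1024, 1024, device="cuda")

    # compute for micro-batch 0 on the default stream
    y = x @ x
    before_comm.record()               # mark the end of the compute kernel

    # "communication" for micro-batch 0 on the comm stream; a scalar multiply
    # stands in for tensor_model_parallel_all_reduce
    with torch.cuda.stream(comm_stream):
        before_comm.wait()             # comm stream waits for the compute to finish
        y_comm = y * 1.0
        after_comm.record()

    # compute for micro-batch 1 is issued on the default stream immediately;
    # it overlaps the comm-stream work because it does not depend on y_comm
    z = x @ x

    # any consumer of y_comm must wait first, like try_wait_event above
    torch.cuda.current_stream().wait_event(after_comm)
    out = y_comm + z
```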
@@ -731,10 +731,10 @@ def _forward_ms_layer(
                 hidden_states[i], residual[i] = self._forward_ms_op_attn(
                     positions[i], hidden_states[i], residual[i], kv_cache,
                     attn_metadata[i])
-        ''' block 3 : shared experts
-        if there is an allreduce ops in shared expert, we can overlap it with the computation of the
-        shared expert for next microbatch or moe gating
-        '''
+
+        # block 3 : shared experts
+        # if there is an allreduce op in the shared experts, we can overlap it with the
+        # computation of the shared experts for the next microbatch or the moe gating
         for i in range(num_micro_batchs):
             ms_metadata.try_wait_event(layer_index, i,
                                        MSEventKey.ATTN_AR_FINISH)
@@ -763,7 +763,6 @@ def _forward_ms_layer(
 
         # block 4 : moe
         for i in range(num_micro_batchs):
-            #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
             # when profile runs, force experts to load balanced tokens
             # to avoid high memory consumption on a single rank.
             # TODO: need a better flag to indicate whether in profile run or not.
@@ -776,13 +775,6 @@ def _forward_ms_layer(
                 enable_force_load_balance = False
 
             if self.mlp.tp_size > 1:
-                #if num_tokens[i] < self.mlp.tp_size:
-                #    target_size = self.mlp.tp_size
-                #    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
-                #                                    dtype=hidden_states[i].dtype,
-                #                                    device=hidden_states[i].device)
-                #    new_hidden_states[:num_tokens[i]] = hidden_states[i]
-                #    hidden_states[i] = new_hidden_states
                 num_token, _ = hidden_states[i].shape
                 padded_num_tokens = (self.mlp.tp_size - num_token %
                                      self.mlp.tp_size) % self.mlp.tp_size
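The retained lines compute how many rows must be appended so the token dimension becomes a multiple of tp_size and can be split evenly across tensor-parallel ranks. The standalone sketch below reproduces that arithmetic with made-up sizes (tp_size and the tensor shape are illustrative, not taken from this PR).

```python
import torch
import torch.nn.functional as F

tp_size = 4
hidden = torch.randn(10, 16)                # 10 tokens, hidden size 16

num_token = hidden.shape[0]
# same formula as in the diff: rounds num_token up to the next multiple of tp_size
padded_num_tokens = (tp_size - num_token % tp_size) % tp_size   # 2 here
if padded_num_tokens > 0:
    # append zero rows at the end of the token dimension
    hidden = F.pad(hidden, (0, 0, 0, padded_num_tokens))

# every rank now receives an equal chunk of (10 + 2) / 4 = 3 tokens
chunks = torch.tensor_split(hidden, tp_size, dim=0)
assert all(c.shape[0] == (num_token + padded_num_tokens) // tp_size for c in chunks)
```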
@@ -805,18 +797,12 @@ def _forward_ms_layer(
             else:
                 real_top_k = self.mlp.experts.top_k
 
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-
             hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(
                 local_hidden_states, router_logits[i], is_prefill, real_top_k,
                 enable_force_load_balance)
 
-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-            ''' the following kernels will be submitted to the comm stream to overlap the computation of the
-            moe computation of next microbatch and the attn computation of next layer
-            '''
+            # the following kernels are submitted to the comm stream so that they overlap
+            # with the moe computation of the next microbatch and the attn computation of the next layer
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][
@@ -826,15 +812,14 @@ def _forward_ms_layer(
             )
             context.before_comm_event.record()
             with torch.npu.stream(ms_metadata.communicate_stream):
-                #with set_multistream_context(context, i):
                 context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (
                         self.mlp.experts.tp_size > 1
                         or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
                 context.after_comm_event.record()
-            # check here
+
             hidden_states[
                 i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
@@ -959,21 +944,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 ["hidden_states", "residual"], config.hidden_size))
 
         # tbo related members
-        self.multistream_config: Optional[MultiStreamConfig] = None
         if VLLM_ENABLE_DBO:
+            self.use_mla = model_config.use_mla
             self.multistream_config = MultiStreamConfig()
-
-            self.use_mla = model_config.use_mla
-            self.multistream_metadata = make_multistream_metadata_ds(
-                start_layer=self.start_layer + self.first_k_dense_replace,
-                end_layer=self.end_layer,
-                causal_lm=getattr(config, "causal_lm", True),
-                multistream_config=self.multistream_config,
-            )
-            self.ms_pre_layer = MultiStreamPreTransformerLayer(
-                self.multistream_metadata)
-            self.ms_post_layer = MultiStreamPostTransformerLayer(
-                self.multistream_metadata)
+            multistream_metadata = make_multistream_metadata_ds(
+                start_layer=self.start_layer + self.first_k_dense_replace,
+                end_layer=self.end_layer,
+                causal_lm=getattr(config, "causal_lm", True),
+                multistream_config=self.multistream_config,
+            )
+            self.ms_pre_layer = MultiStreamPreTransformerLayer(
+                multistream_metadata)
+            self.ms_post_layer = MultiStreamPostTransformerLayer(
+                multistream_metadata)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
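The hunk above builds the multistream machinery only when DBO is enabled, and the metadata no longer needs to live on the model because only the pre/post layers hold it. A simplified, hypothetical sketch of that shape (the classes and the plain-dict metadata below are stand-ins, not the real vllm-ascend types; only the attribute names mirror the diff):

```python
class _PreLayer:
    def __init__(self, metadata):
        self.metadata = metadata

class _PostLayer:
    def __init__(self, metadata):
        self.metadata = metadata

class _Model:
    def __init__(self, enable_dbo: bool, start_layer: int,
                 first_k_dense_replace: int, end_layer: int):
        self.start_layer = start_layer
        self.first_k_dense_replace = first_k_dense_replace
        self.end_layer = end_layer
        if enable_dbo:
            # metadata only covers the MoE layers: [start + dense, end)
            metadata = {"start": start_layer + first_k_dense_replace,
                        "end": end_layer}
            # consumed only by the pre/post layers, so it stays a local
            self.ms_pre_layer = _PreLayer(metadata)
            self.ms_post_layer = _PostLayer(metadata)
```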
@@ -998,11 +981,10 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        num_normal_layers = (self.first_k_dense_replace
-                             if self.multistream_config is not None
+        num_normal_layers = (self.first_k_dense_replace if VLLM_ENABLE_DBO
                              and self.can_run_ms() else self.end_layer -
                              self.start_layer)
-        # if we enable multistream/dbo, only process dense layers here
+
         for i in range(self.start_layer, self.start_layer + num_normal_layers):
             layer = self.layers[i]
             hidden_states, residual = layer(
@@ -1012,13 +994,15 @@ def forward(
                 attn_metadata)
 
         moe_start_layer = self.start_layer + num_normal_layers
-        hidden_states, residual = self._forward_ms_layers(
-            positions=positions,
-            hidden_states=hidden_states,
-            residual=residual,
-            moe_start_layer=moe_start_layer,
-            kv_caches=kv_caches,
-        )
+        if moe_start_layer != self.end_layer:
+            # if we enable multistream/dbo, process the sparse (MoE) layers here
+            hidden_states, residual = self._forward_ms_layers(
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                moe_start_layer=moe_start_layer,
+                kv_caches=kv_caches,
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
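A worked example of the dense/MoE split that forward() now performs: with DBO enabled and can_run_ms() true, only the first_k_dense_replace dense layers run in the normal loop, and the remaining (MoE) layers go through _forward_ms_layers. The numbers below are illustrative DeepSeek-V2-style values, not read from this configuration.

```python
start_layer, end_layer = 0, 61
first_k_dense_replace = 3

dbo_active = True   # stands in for: VLLM_ENABLE_DBO and self.can_run_ms()
num_normal_layers = (first_k_dense_replace if dbo_active
                     else end_layer - start_layer)
moe_start_layer = start_layer + num_normal_layers

print(num_normal_layers)   # 3  -> layers 0..2 take the normal per-layer path
print(moe_start_layer)     # 3  -> layers 3..60 take _forward_ms_layers
# when DBO is off (or can_run_ms() fails), moe_start_layer == end_layer, so the
# new `if moe_start_layer != self.end_layer` guard skips the multistream path
```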
@@ -1045,11 +1029,8 @@ def can_run_ms(self):
         if token_index == 0 or seq_index == 0 or seq_index == len(
                 attn_metadata.query_lens):
             return False
-
-        if self.multistream_config is None:
-            return False
         # check whether the total tokens exceed the threshold
-        if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
+        if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
             return False
         return True
 
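The can_run_ms() checks above gate the multistream path: the proposed split must land strictly inside the batch, and the batch must carry enough tokens to pay for the extra synchronization. The helper below is a simplified, standalone restatement of those checks; the field names mirror the ones referenced in the diff, but the function itself and the threshold value are illustrative, not the real implementation.

```python
def can_split_into_microbatches(query_lens, num_actual_tokens,
                                token_index, seq_index,
                                min_total_tokens_to_split=256):
    # the split must fall strictly inside the batch on both axes
    if token_index == 0 or seq_index == 0 or seq_index == len(query_lens):
        return False
    # too few tokens: overlapping would not amortize the event/stream overhead
    if num_actual_tokens < min_total_tokens_to_split:
        return False
    return True

# example: 4 sequences, split proposed after the 2nd sequence / 300th token
print(can_split_into_microbatches([128, 172, 96, 104], 500,
                                  token_index=300, seq_index=2))  # True
```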