@@ -26,13 +26,13 @@
 # """Inference-only DeepseekV2/DeepseekV3 model."""

 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +63,6 @@
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE


 class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ def __init__(
         vllm_config = get_current_vllm_config()
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
+        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1

         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
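Side note on the changed default: `os.environ` values are always strings, so passing a string default to `get` keeps the expression type-consistent. A minimal standalone sketch of the same pattern (the variable name is illustrative):

```python
import os

# os.environ.get returns a str when the variable is set, so the default
# should be a str as well; int() then parses either way.
enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", "0")) == 1
print(enable_mc2)  # True only when VLLM_ENABLE_MC2=1 is exported
```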
@@ -309,38 +308,36 @@ def __init__(

         self.prefix = prefix
         self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        if VLLM_ENABLE_GRAPH_MODE == "1":
-            self.forward = self.forward_torchair
-        else:
-            self.forward = self.forward_eager  # type: ignore
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

-    def forward_torchair(self,
-                         positions: torch.Tensor,
-                         hidden_states: torch.Tensor,
-                         kv_cache: torch.Tensor = None,
-                         attn_metadata=None):
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            kv_cache: Optional[torch.Tensor] = None,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
-                             kv_cache, attn_metadata)
-
-    def forward_eager(self, positions: torch.Tensor,
-                      hidden_states: torch.Tensor):
-        if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+        if self.enable_graph_mode:
+            return self.mla_attn.impl.forward(self.mla_attn,
+                                              hidden_states_or_q_c,
+                                              hidden_states, None, kv_cache,
+                                              attn_metadata)
         else:
-            hidden_states_or_q_c = hidden_states
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            return self.mla_attn(hidden_states_or_q_c,
+                                 kv_c_normed,
+                                 k_pe,
+                                 output_shape=hidden_states.shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
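For readers unfamiliar with MLA: the eager branch above splits the `kv_a_proj_with_mqa` output into a compressed kv latent and a rotary part before calling the attention op. A self-contained sketch of that split, with made-up shapes (the dimensions are illustrative, not the model's real config):

```python
import torch

# Illustrative dimensions only; real values come from the model config.
kv_lora_rank, qk_rope_head_dim = 512, 64
num_tokens = 4

# Stand-in for the kv_a_proj_with_mqa output in the eager branch.
latent = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)

# Same split as above: compressed kv latent vs. rope dimensions.
kv_c, k_pe = latent.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
assert kv_c.shape == (num_tokens, kv_lora_rank)
assert k_pe.shape == (num_tokens, qk_rope_head_dim)
```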
@@ -408,6 +405,54 @@ def __init__(
                                                eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor

+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow:
+            # we scale both hidden_states and residual before rmsnorm;
+            # the rmsnorm result is not affected by the scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers; we only scale it
+                # on the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow:
+            # scale the DeepseekV2MLP output, which is the input of the
+            # next decoder layer's input_layernorm.
+            # The scaling of the DeepseekV2MOE output is done in the forward
+            # of DeepseekV2MOE.
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+

 class CustomDeepseekV2Model(nn.Module):

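The FP16 overflow fix relies on RMSNorm being scale-invariant: multiplying the input by a positive constant leaves the normalized output unchanged (up to epsilon), so dividing by `routed_scaling_factor` before the norm cannot change the result. A quick sketch verifying this with a bare RMSNorm (no learned weight, for simplicity):

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Bare RMSNorm without a learned weight, for illustration.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8)
scale = 1. / 2.5  # stands in for 1 / routed_scaling_factor

# Pre-scaling the input barely moves the RMSNorm output (only eps breaks
# exact invariance), which is why the scaling in the layer above is safe.
assert torch.allclose(rms_norm(x), rms_norm(x * scale), atol=1e-4)
```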
@@ -459,7 +504,9 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, residual,
+                kv_caches[i - self.start_layer]
+                if kv_caches is not None else None,
+                attn_metadata)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
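The `i - self.start_layer` offset exists because each pipeline-parallel rank holds a local, zero-indexed `kv_caches` list, while layer indices are global across the model. A toy illustration (the layer range is made up):

```python
# Hypothetical pipeline-parallel slice: this rank owns global layers 30..59,
# but its local kv-cache list is 0-indexed.
start_layer, end_layer = 30, 60
kv_caches = [f"cache for layer {start_layer + j}"
             for j in range(end_layer - start_layer)]

for i in range(start_layer, end_layer):
    local = kv_caches[i - start_layer]  # same offset as the loop above
    assert local == f"cache for layer {i}"
```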
@@ -514,6 +566,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+

 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
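Since the attention wrapper now reads `enable_graph_mode` from `additional_config` instead of the old `VLLM_ENABLE_GRAPH_MODE` env var, the TorchAir path is selected per engine instance. A hedged usage sketch, assuming the installed vllm build forwards `additional_config` from the entrypoint into `VllmConfig` (the model name is just an example):

```python
from vllm import LLM

# Assumes additional_config is exposed as an engine arg in this vllm build;
# the key matches the one the attention wrapper reads above.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          additional_config={"enable_graph_mode": True})
outputs = llm.generate("Hello, my name is")
```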