2626# """Inference-only DeepseekV2/DeepseekV3 model."""
2727
2828import os
29- from typing import Any , Dict , Optional , Union
29+ from typing import Any , Dict , List , Optional , Union
3030
3131import torch
3232import torch .distributed as dist
3333from torch import nn
3434from transformers import PretrainedConfig
35- from vllm .attention import Attention
35+ from vllm .attention import Attention , AttentionMetadata
3636from vllm .config import (CacheConfig , ModelConfig , VllmConfig ,
3737 get_current_vllm_config )
3838from vllm .distributed import (get_dp_group , get_pp_group ,
@@ -64,7 +64,6 @@
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
 
 
 class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ def __init__(
         vllm_config = get_current_vllm_config()
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
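+        # os.environ.get returns strings for set variables, so a string
+        # default keeps the int() parsing consistent.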
+        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
 
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ def __init__(
 
         self.prefix = prefix
         self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        if VLLM_ENABLE_GRAPH_MODE == "1":
-            self.forward = self.forward_torchair
-        else:
-            self.forward = self.forward_eager  # type: ignore
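+        # Graph (torchair) mode is now read from additional_config rather
+        # than the VLLM_ENABLE_GRAPH_MODE environment variable.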
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
-    def forward_torchair(self,
-                         positions: torch.Tensor,
-                         hidden_states: torch.Tensor,
-                         kv_cache: torch.Tensor = None,
-                         attn_metadata=None):
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            kv_cache: Optional[torch.Tensor] = None,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
-                             kv_cache, attn_metadata)
-
-    def forward_eager(self, positions: torch.Tensor,
-                      hidden_states: torch.Tensor):
-        if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
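+        # Graph mode bypasses the Attention layer wrapper and calls the MLA
+        # impl directly, passing kv_cache and attn_metadata explicitly.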
+        if self.enable_graph_mode:
+            return self.mla_attn.impl.forward(self.mla_attn,
+                                              hidden_states_or_q_c,
+                                              hidden_states, None, kv_cache,
+                                              attn_metadata)
         else:
-            hidden_states_or_q_c = hidden_states
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
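+            # Eager path: split the joint kv projection into the compressed
+            # kv part and the rotary position part (k_pe) before MLA.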
+            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            return self.mla_attn(hidden_states_or_q_c,
+                                 kv_c_normed,
+                                 k_pe,
+                                 output_shape=hidden_states.shape)
 
 
 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ def __init__(
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
 
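+    # forward is overridden here so kv_cache and attn_metadata can be
+    # threaded through to self_attn.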
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before rmsnorm;
+            # the rmsnorm result is not affected by the scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, so we only scale it
+                # on the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output, which is the input of the
+            # input_layernorm of the next decoder layer.
+            # The scaling of the DeepseekV2MOE output is done in the forward
+            # of DeepseekV2MOE.
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+
 
 class CustomDeepseekV2Model(nn.Module):
 
@@ -459,7 +504,9 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(positions, hidden_states, residual)
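+        # Each layer gets its own kv-cache entry; attn_metadata is shared
+        # across all layers.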
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, residual,
+                kv_caches[i -
+                          self.start_layer] if kv_caches is not None else None,
+                attn_metadata)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -514,6 +566,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
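+        # Simply delegate to the model, forwarding caches and metadata.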
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass