 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -502,7 +503,9 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: Optional[torch.Tensor] = None,
-        attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
@@ -516,6 +519,8 @@ def forward(
                                     dtype=hidden_states_or_q_c.dtype,
                                     device=hidden_states_or_q_c.device)
                forward_kwargs['output'] = output
+                forward_kwargs['rotary_cos'] = rotary_cos
+                forward_kwargs['rotary_sin'] = rotary_sin
 
             output = self.mla_attn.impl.forward(self.mla_attn,
                                                 hidden_states_or_q_c,
@@ -607,6 +612,8 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -626,6 +633,8 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
         )
 
         if hidden_states.dtype == torch.float16:
@@ -703,9 +712,43 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                   rotary_dim=config.qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def prepare_decoder_rotary_cos_sin(
+        self, attn_metadata: Optional[AttentionMetadata] = None
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if (attn_metadata is not None and attn_metadata.num_decodes is not None
+                and attn_metadata.attn_state):
+            has_decode = attn_metadata.num_decodes > 0
+            running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
+            if has_decode and running_in_graph:
+                cos = self.rotary_emb.cos_cached
+                sin = self.rotary_emb.sin_cached
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+                return cos, sin
+        return None, None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -726,13 +769,17 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        # In graph mode with the v1 engine, precomputing cos and sin here
+        # eliminates repeated calculations in each decode layer.
+        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(attn_metadata)
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, rotary_cos, rotary_sin)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
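A minimal sketch (not part of the diff) of the broadcast contract that `prepare_decoder_rotary_cos_sin` sets up: the cos/sin tables are gathered at the decode positions and reshaped to `[num_tokens, 1, 1, rope_dim]` so they broadcast across heads. The `rotate_half` helper, tensor names, and shapes below are illustrative assumptions; the actual rotation is applied inside the MLA attention impl on the NPU.

```python
# Illustrative only: how the precomputed cos/sin broadcast against a
# per-head rope query during decode. All shapes are assumptions.
import torch

num_tokens, num_heads, rope_dim, max_pos = 4, 16, 64, 1024

# Stand-ins for rotary_emb.cos_cached / sin_cached: [max_pos, rope_dim]
cos_cached = torch.randn(max_pos, rope_dim)
sin_cached = torch.randn(max_pos, rope_dim)
input_positions = torch.tensor([7, 42, 100, 511])  # one position per decode token

# Same gather + unsqueeze as prepare_decoder_rotary_cos_sin:
# [num_tokens, rope_dim] -> [num_tokens, 1, 1, rope_dim]
cos = cos_cached[input_positions][:, None, None, :]
sin = sin_cached[input_positions][:, None, None, :]


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Generic rotate-half; the real NPU op may pair elements differently.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


# A decode-step rope query: [num_tokens, num_heads, 1, rope_dim]
q_pe = torch.randn(num_tokens, num_heads, 1, rope_dim)
q_rotated = q_pe * cos + rotate_half(q_pe) * sin
print(q_rotated.shape)  # torch.Size([4, 16, 1, 64])
```

Computing `cos`/`sin` once per model forward and reusing them in every decoder layer is exactly what the new `rotary_cos`/`rotary_sin` arguments thread through the call chain above.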