 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor)

@@ -671,6 +672,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+
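+        # Cache layer 0's rotary-embedding tables so cos/sin can be gathered once
+        # per forward pass; this assumes every decoder layer shares the same rotary
+        # embedding, which is why only layer 0's tables are reused below.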
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.max_position_embeddings = self.layers[0].self_attn.rotary_emb.max_position_embeddings
+        self.cos_cached = self.layers[0].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[0].self_attn.rotary_emb.sin_cached

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -695,6 +702,31 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        # Gather rotary cos/sin once before the decoder layers so each layer can reuse them.
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
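+        # Decode path (torchair graph mode): slice the cached tables, gather rows by the
+        # decode token positions, add broadcast dims via [:, None, None, :], and stash the
+        # result on the decode metadata so the attention layers can reuse it.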
+        if attn_metadata.num_decodes > 0 and self.running_in_graph:
+            seq_len = self.max_position_embeddings
+            cos = self.cos_cached[:seq_len].to(dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.decode.input_positions]
+            sin = sin[attn_metadata.decode.input_positions]
+            attn_metadata.decode.cos = cos[:, None, None, :]
+            attn_metadata.decode.sin = sin[:, None, None, :]
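+        # Prefill path under torchair graph mode: same gather, keyed by the prefill token positions.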
+        if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+            seq_len = self.max_position_embeddings
+            cos = self.cos_cached[:seq_len].to(dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            attn_metadata.prefill.cos = cos[:, None, None, :]
+            attn_metadata.prefill.sin = sin[:, None, None, :]
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(