@@ -74,6 +74,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
                                npu_wait_tensor, vllm_version_is)
 
@@ -794,6 +795,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.max_position_embeddings = self.layers[0].self_attn.rotary_emb.max_position_embeddings
+        self.scaling_factor = self.layers[0].self_attn.rotary_emb.scaling_factor
+        self.cos_cached = self.layers[0].self_attn.rotary_emb.cos_cached
+        self.sin_cached = self.layers[0].self_attn.rotary_emb.sin_cached
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -819,6 +827,30 @@ def forward(
             residual = intermediate_tensors["residual"]
 
         replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+        # Precompute rotary cos/sin once, before running the decoder layers
+        self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
+        if attn_metadata.num_decodes > 0 and self.running_in_graph:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.decode.input_positions]
+            sin = sin[attn_metadata.decode.input_positions]
+            attn_metadata.decode.cos = cos[:, None, None, :]
+            attn_metadata.decode.sin = sin[:, None, None, :]
+        if attn_metadata.num_prefills > 0 and self.torchair_graph_enabled:
+            seq_len = self.max_position_embeddings * self.scaling_factor
+            cos = self.cos_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            sin = self.sin_cached[:seq_len].to(
+                dtype=hidden_states.dtype)
+            cos = cos[attn_metadata.prefill.input_positions]
+            sin = sin[attn_metadata.prefill.input_positions]
+            attn_metadata.prefill.cos = cos[:, None, None, :]
+            attn_metadata.prefill.sin = sin[:, None, None, :]
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
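
For orientation only, here is a minimal, self-contained sketch of the gather-and-broadcast pattern the forward() hunk adds: the cached rotary cos/sin tables are indexed by token position and unsqueezed to a [num_tokens, 1, 1, rot_dim] layout so the attention layers can consume them directly. The tensor names and sizes below are illustrative assumptions, not values taken from the patch.

import torch

# Illustrative stand-ins for rotary_emb.cos_cached / sin_cached (assumed shapes).
max_pos, rot_dim = 4096, 64
cos_cached = torch.randn(max_pos, rot_dim)
sin_cached = torch.randn(max_pos, rot_dim)

# Positions of the tokens in the current (hypothetical) decode batch.
input_positions = torch.tensor([17, 18, 19, 20])

cos = cos_cached[input_positions]    # [num_tokens, rot_dim]
sin = sin_cached[input_positions]
cos = cos[:, None, None, :]          # [num_tokens, 1, 1, rot_dim], broadcastable over heads
sin = sin[:, None, None, :]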