 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
@@ -500,12 +501,13 @@ def __init__(
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
 
-    def forward(
-            self,
-            positions: torch.Tensor,
-            hidden_states: torch.Tensor,
-            kv_cache: Optional[torch.Tensor] = None,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                kv_cache: Optional[torch.Tensor] = None,
+                attn_metadata: Optional[AttentionMetadata] = None,
+                rotary_cos: Optional[torch.Tensor] = None,
+                rotary_sin: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             use_multistream_mla = (self.enable_multistream_mla
@@ -526,6 +528,8 @@ def forward(
                 dtype=hidden_states_or_q_c.dtype,
                 device=hidden_states_or_q_c.device)
             forward_kwargs['output'] = output
+            forward_kwargs['rotary_cos'] = rotary_cos
+            forward_kwargs['rotary_sin'] = rotary_sin
 
         output = self.mla_attn.impl.forward(self.mla_attn,
                                             hidden_states_or_q_c,
@@ -617,6 +621,8 @@ def forward(
         residual: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -636,6 +642,8 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
         )
 
         if hidden_states.dtype == torch.float16:
@@ -713,9 +721,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
+        self.rotary_emb = get_rope(config.qk_rope_head_dim,
+                                   rotary_dim=config.qk_rope_head_dim,
+                                   max_position=max_position_embeddings,
+                                   base=rope_theta,
+                                   rope_scaling=rope_scaling,
+                                   is_neox_style=False)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def prepare_decoder_rotary_cos_sin(
+            self,
+            attn_metadata: Optional[AttentionMetadata] = None
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if (envs.VLLM_USE_V1 and attn_metadata is not None
+                and attn_metadata.num_decodes is not None
+                and attn_metadata.attn_state is not None):
+            has_decode = attn_metadata.num_decodes > 0
+            running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
+                AscendAttentionState.DecodeOnly,
+                AscendAttentionState.SpecDecoding
+            ]
+            if has_decode and running_in_graph:
+                cos = self.rotary_emb.cos_cached
+                sin = self.rotary_emb.sin_cached
+                cos = cos[attn_metadata.decode.input_positions]
+                sin = sin[attn_metadata.decode.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+                return cos, sin
+        return None, None
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -736,13 +782,18 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        # In graph mode with the v1 engine, precomputing cos and sin here
+        # eliminates repeated calculations in each decode layer.
+        rotary_cos, rotary_sin = self.prepare_decoder_rotary_cos_sin(
+            attn_metadata)
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, residual,
                 kv_caches[i -
                           self.start_layer] if kv_caches is not None else None,
-                attn_metadata)
+                attn_metadata, rotary_cos, rotary_sin)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
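Note on the optimization above: in torchair graph mode on the v1 engine, every decoder layer applies RoPE to the same decode positions, so the per-position cos/sin slices can be gathered once per forward pass (`prepare_decoder_rotary_cos_sin`) and handed down to each layer's MLA attention instead of being re-gathered in every layer. The sketch below illustrates that idea in plain PyTorch; `build_cos_sin_cache`, `apply_rotary_emb`, and the `[bs, 1, 1, rope_dim]` broadcast layout are illustrative assumptions (it also uses the common neox-style rotation for brevity, whereas this patch builds its rope with `is_neox_style=False`), not the vllm-ascend API.

```python
# Minimal sketch of "precompute rotary cos/sin once, reuse in every layer".
# All names here are illustrative, not the vllm-ascend API.
import torch


def build_cos_sin_cache(rope_dim: int, max_position: int, base: float = 10000.0):
    """Standard RoPE tables: one row of cos/sin values per position."""
    inv_freq = 1.0 / (base**(torch.arange(0, rope_dim, 2).float() / rope_dim))
    freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_position, rope_dim]
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [bs, n_heads, 1, rope_dim]; cos/sin broadcast over the head dimension.
    return x * cos + rotate_half(x) * sin


cos_cached, sin_cached = build_cos_sin_cache(rope_dim=64, max_position=4096)

# Done once per decode step (the role of prepare_decoder_rotary_cos_sin):
positions = torch.tensor([17, 256, 1023])      # one decode token per sequence
cos = cos_cached[positions][:, None, None, :]  # [bs, 1, 1, rope_dim]
sin = sin_cached[positions][:, None, None, :]

# Every decoder layer then reuses the same two tensors instead of
# re-indexing the cache (and rebuilding the broadcast shape) per layer.
num_layers, n_heads = 4, 8
for _ in range(num_layers):
    q_pe = torch.randn(3, n_heads, 1, 64)      # this layer's rope query slice
    q_rot = apply_rotary_emb(q_pe, cos, sin)
```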