@@ -446,6 +446,10 @@ def __init__(
446446 ascend_config = get_ascend_config ()
447447 self .torchair_graph_enabled = ascend_config .torchair_graph_config .enabled
448448
449+ self .cos = None
450+ self .sin = None
451+ self .debug_layer_idx = kwargs .get ('debug_layer_idx' , 0 )
452+
449453 def _v_up_proj_and_o_proj (self , x ):
450454 # Convert from (B, N, L) to (N, B, L)
451455 x = x .view (- 1 , self .num_heads , self .kv_lora_rank ).transpose (0 , 1 )
@@ -738,19 +742,20 @@ def forward(
738742 decode_ql_nope , decode_q_pe = \
739743 self ._q_proj_and_k_up_proj (decode_hs_or_q_c )
740744 if self .running_in_graph :
741- seq_len = self .rotary_emb .max_position_embeddings
742- cos = self .rotary_emb .cos_cached [:seq_len ].to (
743- dtype = decode_q_pe .dtype )
744- sin = self .rotary_emb .sin_cached [:seq_len ].to (
745- dtype = decode_q_pe .dtype )
746- cos = cos [attn_metadata .decode .input_positions ]
747- sin = sin [attn_metadata .decode .input_positions ]
748- cos = cos [:, None , None , :]
749- sin = sin [:, None , None , :]
750-
751- decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
745+ # Within a single decode step, cos and sin depend only on the decode input positions, so they are identical for every layer and can be reused.
746+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
747+ seq_len = self .rotary_emb .max_position_embeddings
748+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (
749+ dtype = decode_q_pe .dtype )
750+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (
751+ dtype = decode_q_pe .dtype )
752+ self .cos = self .cos [attn_metadata .decode .input_positions ]
753+ self .sin = self .sin [attn_metadata .decode .input_positions ]
754+ self .cos = self .cos [:, None , None , :]
755+ self .sin = self .sin [:, None , None , :]
756+ decode_q_pe = self .rope_single (decode_q_pe , self .cos , self .sin )
752757 decode_k_pe , decode_k_nope = self .exec_kv (
753- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
758+ hidden_states_or_kv_c_normed , self . cos , self . sin , kv_cache ,
754759 attn_metadata .slot_mapping )
755760 else :
756761 decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
0 commit comments