@@ -408,6 +408,10 @@ def __init__(
408408 self .enable_graph_mode = additional_config .get (
409409 "enable_graph_mode" , False )
410410
411+ self .cos = None
412+ self .sin = None
413+ self .debug_layer_idx = extra_impl_args .get ('debug_layer_idx' , 0 )
414+
411415 def _v_up_proj_and_o_proj (self , x ):
412416 # Convert from (B, N, L) to (N, B, L)
413417 x = x .view (- 1 , self .num_heads , self .kv_lora_rank ).transpose (0 , 1 )
@@ -700,18 +704,20 @@ def forward(
700704 decode_ql_nope , decode_q_pe = \
701705 self ._q_proj_and_k_up_proj (decode_hs_or_q_c )
702706 if self .running_in_graph :
703- seq_len = self .rotary_emb .max_position_embeddings
704- cos = self .rotary_emb .cos_cached [:seq_len ].to (
705- dtype = decode_q_pe .dtype )
706- sin = self .rotary_emb .sin_cached [:seq_len ].to (
707- dtype = decode_q_pe .dtype )
708- cos = cos [attn_metadata .decode .input_positions ]
709- sin = sin [attn_metadata .decode .input_positions ]
710- cos = cos [:, None , None , :]
711- sin = sin [:, None , None , :]
712- decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
707+ # During the autoregressive decoding process, the cos and sin values are exactly the same for each layer
708+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
709+ seq_len = self .rotary_emb .max_position_embeddings
710+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (
711+ dtype = decode_q_pe .dtype )
712+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (
713+ dtype = decode_q_pe .dtype )
714+ self .cos = self .cos [attn_metadata .decode .input_positions ]
715+ self .sin = self .sin [attn_metadata .decode .input_positions ]
716+ self .cos = self .cos [:, None , None , :]
717+ self .sin = self .sin [:, None , None , :]
718+ decode_q_pe = self .rope_single (decode_q_pe , self .cos , self .sin )
713719 decode_k_pe , decode_k_nope = self .exec_kv (
714- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
720+ hidden_states_or_kv_c_normed , self . cos , self . sin , kv_cache ,
715721 attn_metadata .slot_mapping )
716722 else :
717723 decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
0 commit comments