@@ -465,6 +465,10 @@ def __init__(
465465 self .enable_graph_mode = additional_config .get (
466466 "enable_graph_mode" , False )
467467
468+ self .cos = None
469+ self .sin = None
470+ self .debug_layer_idx = kwargs .get ('debug_layer_idx' , 0 )
471+
468472 def _v_up_proj_and_o_proj (self , x ):
469473 # Convert from (B, N, L) to (N, B, L)
470474 x = x .view (- 1 , self .num_heads , self .kv_lora_rank ).transpose (0 , 1 )
@@ -757,18 +761,20 @@ def forward(
757761 decode_ql_nope , decode_q_pe = \
758762 self ._q_proj_and_k_up_proj (decode_hs_or_q_c )
759763 if self .running_in_graph :
760- seq_len = self .rotary_emb .max_position_embeddings
761- cos = self .rotary_emb .cos_cached [:seq_len ].to (
762- dtype = decode_q_pe .dtype )
763- sin = self .rotary_emb .sin_cached [:seq_len ].to (
764- dtype = decode_q_pe .dtype )
765- cos = cos [attn_metadata .decode .input_positions ]
766- sin = sin [attn_metadata .decode .input_positions ]
767- cos = cos [:, None , None , :]
768- sin = sin [:, None , None , :]
769- decode_q_pe = self .rope_single (decode_q_pe , cos , sin )
764+ # During the autoregressive decoding process, the cos and sin values are exactly the same for each layer
765+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
766+ seq_len = self .rotary_emb .max_position_embeddings
767+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (
768+ dtype = decode_q_pe .dtype )
769+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (
770+ dtype = decode_q_pe .dtype )
771+ self .cos = self .cos [attn_metadata .decode .input_positions ]
772+ self .sin = self .sin [attn_metadata .decode .input_positions ]
773+ self .cos = self .cos [:, None , None , :]
774+ self .sin = self .sin [:, None , None , :]
775+ decode_q_pe = self .rope_single (decode_q_pe , self .cos , self .sin )
770776 decode_k_pe , decode_k_nope = self .exec_kv (
771- hidden_states_or_kv_c_normed , cos , sin , kv_cache ,
777+ hidden_states_or_kv_c_normed , self . cos , self . sin , kv_cache ,
772778 attn_metadata .slot_mapping )
773779 else :
774780 decode_q_pe [...], decode_k_pe [...] = self .rotary_emb (
0 commit comments