@@ -1000,6 +1000,10 @@ def __init__(
10001000 self .w_kc = None
10011001 self .w_vc = None
10021002
1003+ self .cos = None
1004+ self .sin = None
1005+ self .debug_layer_idx = extra_impl_args .get ('debug_layer_idx' , 0 )
1006+
10031007 self .enable_graph_mode = False
10041008 additional_config = get_current_vllm_config ().additional_config
10051009 if additional_config :
@@ -1128,17 +1132,18 @@ def forward(
11281132 q_nope , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ],
11291133 dim = - 1 )
11301134 if k_pe is None and attn_metadata .decode_metadata :
1131- seq_len = self .rotary_emb .max_position_embeddings
1132-
1133- cos = self .rotary_emb .cos_cached [:seq_len ].to (dtype = q_pe .dtype )
1134- sin = self .rotary_emb .sin_cached [:seq_len ].to (dtype = q_pe .dtype )
1135- cos = cos [attn_metadata .input_positions ]
1136- sin = sin [attn_metadata .input_positions ]
1137- cos = cos [:, None , None , :]
1138- sin = sin [:, None , None , :]
1139-
1140- q_pe = self .rope_single (q_pe , cos , sin )
1141- k_pe , k_nope = self .exec_kv (hidden_states_or_kv_c_normed , cos , sin ,
1135+ if self .debug_layer_idx == 0 or self .cos is None or self .sin is None :
1136+ seq_len = self .rotary_emb .max_position_embeddings
1137+
1138+ self .cos = self .rotary_emb .cos_cached [:seq_len ].to (dtype = q_pe .dtype )
1139+ self .sin = self .rotary_emb .sin_cached [:seq_len ].to (dtype = q_pe .dtype )
1140+ self .cos = self .cos [attn_metadata .input_positions ]
1141+ self .sin = self .sin [attn_metadata .input_positions ]
1142+ self .cos = self .cos [:, None , None , :]
1143+ self .sin = self .sin [:, None , None , :]
1144+
1145+ q_pe = self .rope_single (q_pe , self .cos , self .sin )
1146+ k_pe , k_nope = self .exec_kv (hidden_states_or_kv_c_normed , self .cos , self .sin ,
11421147 kv_cache , attn_metadata .slot_mapping )
11431148 else :
11441149 if k_pe is None :
0 commit comments