     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -444,21 +465,22 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
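In short, the diff keeps a single RoPE call site: when rotary_emb is a DeepseekScalingRotaryEmbedding (YaRN scaling) it is invoked directly on the 3-D (seq_len, num_heads, rope_dim) tensors, and otherwise apply_pure_rope flattens the head dimension before the call and restores the original shapes afterwards. Below is a minimal standalone sketch of that dispatch pattern, assuming a rope callable that expects flattened (seq_len, num_heads * rope_dim) inputs; the names and shapes are illustrative, not the vLLM API.

    import torch

    def apply_pure_rope(rotary_emb, input_positions: torch.Tensor,
                        q_pe: torch.Tensor, k_pe: torch.Tensor):
        # Flatten the head dimension for rope implementations that expect
        # 2-D (seq_len, heads * rope_dim) inputs, then restore the shapes.
        seq_len = input_positions.size(0)
        q_shape, k_shape = q_pe.shape, k_pe.shape
        q_pe, k_pe = rotary_emb(input_positions,
                                q_pe.reshape(seq_len, -1),
                                k_pe.reshape(seq_len, -1))
        return q_pe.view(q_shape), k_pe.view(k_shape)

    # Call sites then stay identical for both paths, e.g.:
    # rope_fn = rotary_emb if use_yarn_rope else (
    #     lambda pos, q, k: apply_pure_rope(rotary_emb, pos, q, k))
    # q_pe, k_pe = rope_fn(input_positions, q_pe, k_pe)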