diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index 8e584cca3657..cd8c08e5ab47 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -26,7 +26,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -444,13 +465,14 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
@@ -458,7 +480,7 @@
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index f5fede4d8226..fdd584f9d6d8 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -414,7 +414,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py
index a4829aa1a572..81f82b182f1f 100644
--- a/vllm/model_executor/models/deepseek_v3.py
+++ b/vllm/model_executor/models/deepseek_v3.py
@@ -422,7 +422,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
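
For context, a minimal standalone sketch of the shape handling that `apply_pure_rope` introduces: when the model has no `rope_scaling` config, `rotary_emb` is a plain `RotaryEmbedding` that the diff calls on flattened `[num_tokens, num_heads * head_dim]` tensors, so the wrapper flattens, applies RoPE, and restores the original 3-D shapes. `toy_rope` and `pure_rope_wrapper` below are hypothetical stand-ins, not vLLM APIs.

```python
import torch


def toy_rope(positions: torch.Tensor, q: torch.Tensor,
             k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for rotary_emb(positions, q, k) operating on
    # flattened [num_tokens, num_heads * head_dim] inputs.
    return q, k


def pure_rope_wrapper(positions: torch.Tensor, q_pe: torch.Tensor,
                      k_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors apply_pure_rope: flatten, apply RoPE, restore shapes.
    seq_len = positions.size(0)
    q_shape, k_shape = q_pe.shape, k_pe.shape
    q_pe, k_pe = toy_rope(positions, q_pe.reshape(seq_len, -1),
                          k_pe.reshape(seq_len, -1))
    return q_pe.view(q_shape), k_pe.view(k_shape)


positions = torch.arange(4)
q_pe = torch.randn(4, 16, 64)  # [num_tokens, num_heads, qk_rope_head_dim]
k_pe = torch.randn(4, 1, 64)   # [num_tokens, 1, qk_rope_head_dim]
q_out, k_out = pure_rope_wrapper(positions, q_pe, k_pe)
assert q_out.shape == q_pe.shape and k_out.shape == k_pe.shape
```

In the yarn case (`DeepseekScalingRotaryEmbedding`), the diff keeps the previous behavior and passes the 3-D tensors to `rotary_emb` directly, which is why `forward` picks `rope_fn` based on `self.use_yarn_rope`.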