
Commit 98fd089

[VLM] Add MLA with pure RoPE support for deepseek-vl2 models (#12729)
1 parent: 249824c
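
In short: the MLA attention backend previously assumed a YaRN-scaled rotary embedding (DeepseekScalingRotaryEmbedding); deepseek-vl2 models configure plain RoPE, so the backend now detects which variant it was given and routes accordingly. A minimal sketch of that dispatch, assuming a simplified attention object; pick_rope_fn is a hypothetical helper, not part of this diff:

    # Hypothetical sketch of the dispatch this commit adds; the real code
    # caches the check as self.use_yarn_rope in the MLA backend's __init__.
    from vllm.model_executor.layers.rotary_embedding import (
        DeepseekScalingRotaryEmbedding)

    def pick_rope_fn(attn):
        # YaRN rope consumes MLA's (q_pe, k_pe) tensors as-is; pure RoPE
        # goes through the reshape wrapper apply_pure_rope (see utils.py).
        if isinstance(attn.rotary_emb, DeepseekScalingRotaryEmbedding):
            return attn.rotary_emb
        return attn.apply_pure_rope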

File tree

3 files changed (+30 -6 lines)

vllm/attention/backends/mla/utils.py

Lines changed: 26 additions & 4 deletions
@@ -26,7 +26,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ def __init__(
         self.v_head_dim = v_head_dim

         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError

+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -444,21 +465,22 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)

         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)

             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)

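The wrapper exists because MLA carries the rotary slice of the query as [num_tokens, num_heads, qk_rope_head_dim] with a single shared k_pe head, while plain RotaryEmbedding operates on tensors flattened to [num_tokens, num_heads * head_size], as the reshape in the diff suggests. A standalone sketch of the reshape round-trip apply_pure_rope performs; rope() is an identity stand-in so this runs without vLLM, and the shapes are illustrative:

    import torch

    num_tokens, num_heads, rope_dim = 4, 16, 64
    positions = torch.arange(num_tokens)
    q_pe = torch.randn(num_tokens, num_heads, rope_dim)
    k_pe = torch.randn(num_tokens, 1, rope_dim)  # MLA shares one rope'd key head

    def rope(positions, q, k):
        # Stand-in for RotaryEmbedding.forward; the rotation itself is omitted.
        return q, k

    ori_q_shape, ori_k_shape = q_pe.shape, k_pe.shape
    q_flat, k_flat = rope(positions,
                          q_pe.reshape(num_tokens, -1),  # [tokens, heads*dim]
                          k_pe.reshape(num_tokens, -1))  # [tokens, dim]
    q_pe, k_pe = q_flat.view(ori_q_shape), k_flat.view(ori_k_shape)
    assert q_pe.shape == (num_tokens, num_heads, rope_dim)
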
vllm/model_executor/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
@@ -414,7 +414,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")

-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
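
This guard is the crux of the fix: deepseek-vl2 checkpoints can leave rope_scaling unset (None), and the old unconditional subscript assignment raised a TypeError before get_rope was ever reached. With the guard, get_rope receives rope_scaling=None and builds a plain RotaryEmbedding, which the MLA backend then routes through apply_pure_rope. A minimal illustration, where patch_rope_scaling is a hypothetical helper mirroring the guarded lines:

    def patch_rope_scaling(rope_scaling):
        # Only force the deepseek_yarn rope type when scaling is configured.
        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        return rope_scaling

    assert patch_rope_scaling(None) is None  # deepseek-vl2: pure RoPE, no crash
    assert patch_rope_scaling({"factor": 40})["rope_type"] == 'deepseek_yarn'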

vllm/model_executor/models/deepseek_v3.py

Lines changed: 2 additions & 1 deletion
@@ -422,7 +422,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")

-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
