
Commit 1f98a7e

enable MTP on Gaudi
Mainly refers to: [Model][Speculative Decoding] DeepSeek MTP spec decode (vllm-project#12755). Enables MTP for HPU from the deepseek_r1_upstream branch.
1 parent 89a4ca5 commit 1f98a7e

17 files changed (+240 -93 lines changed)

vllm/attention/layer.py

Lines changed: 23 additions & 13 deletions
@@ -153,8 +153,10 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        # For some alternate attention backends like MLA the attention output
+        # shape does not match the query shape, so we optionally let the model
+        # definition specify the output tensor shape.
+        output_shape: Optional[torch.Size] = None,
     ) -> torch.Tensor:
         # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
         # directly, use `self.kv_cache` and
@@ -164,17 +166,25 @@ def forward(
         if ctx_attn_metadata.enable_kv_scales_calculation:
             self.calc_kv_scales(key, value)
         if self.use_output:
-            output = torch.empty_like(query)
-            hidden_size = query.size(-1)
-            # Reshape the query, key, and value tensors.
-            # NOTE(woosuk): We do this outside the custom op to minimize the
-            # CPU overheads from the non-CUDA-graph regions.
-            query = query.view(-1, self.num_heads, self.head_size)
-            output = output.view(-1, self.num_heads, self.head_size)
-            if key is not None:
-                key = key.view(-1, self.num_kv_heads, self.head_size)
-            if value is not None:
-                value = value.view(-1, self.num_kv_heads, self.head_size)
+            output_shape = (output_shape
+                            if output_shape is not None else query.shape)
+            output = torch.empty(output_shape,
+                                 dtype=query.dtype,
+                                 device=query.device)
+            hidden_size = output_shape[-1]
+            # We skip reshaping query, key and value tensors for the MLA
+            # backend since these tensors have different semantics and are
+            # processed differently.
+            if not self.use_mla:
+                # Reshape the query, key, and value tensors.
+                # NOTE(woosuk): We do this outside the custom op to minimize the
+                # CPU overheads from the non-CUDA-graph regions.
+                query = query.view(-1, self.num_heads, self.head_size)
+                output = output.view(-1, self.num_heads, self.head_size)
+                if key is not None:
+                    key = key.view(-1, self.num_kv_heads, self.head_size)
+                if value is not None:
+                    value = value.view(-1, self.num_kv_heads, self.head_size)
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             ctx_attn_metadata = forward_context.attn_metadata
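With this change an MLA-style caller can request an output buffer whose last dimension differs from the query's. A minimal, self-contained sketch of the new allocation logic; the helper name, head counts and tensor sizes below are illustrative and not part of the diff:

from typing import Optional
import torch

def allocate_attn_output(query: torch.Tensor,
                         output_shape: Optional[torch.Size] = None,
                         use_mla: bool = False,
                         num_heads: int = 16,
                         head_size: int = 64) -> torch.Tensor:
    # Mirrors the hunk above: the output follows output_shape when the caller
    # provides one (e.g. MLA), otherwise it mirrors the query shape.
    shape = output_shape if output_shape is not None else query.shape
    output = torch.empty(shape, dtype=query.dtype, device=query.device)
    if not use_mla:
        # Non-MLA backends still get the (tokens, heads, head_size) layout.
        output = output.view(-1, num_heads, head_size)
    return output

# An MLA caller asks for a (num_tokens, hidden_size) output even though the
# query is laid out as (num_tokens, num_heads * qk_head_dim).
q = torch.randn(8, 16 * 192)
out = allocate_attn_output(q, output_shape=torch.Size((8, 7168)), use_mla=True)
assert out.shape == (8, 7168)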

vllm/config.py

Lines changed: 33 additions & 7 deletions
@@ -784,7 +784,7 @@ def get_hidden_size(self) -> int:
     def is_deepseek_mla(self) -> bool:
         return (hasattr(self.hf_text_config, "model_type")) \
             and (self.hf_text_config.model_type in \
-                ('deepseek_v2', 'deepseek_v3'))\
+                ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\
             and (self.hf_text_config.kv_lora_rank is not None)
 
     def get_head_size(self) -> int:
@@ -877,8 +877,12 @@ def get_num_attention_heads(self,
     def get_layers_start_end_indices(
             self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
+        if self.hf_text_config.model_type == "deepseek_mtp":
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 0)
+        else:
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
@@ -1741,6 +1745,18 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
 
+    @staticmethod
+    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
+        if hf_config.model_type == "deepseek_v3":
+            hf_config.model_type = "deepseek_mtp"
+        if hf_config.model_type == "deepseek_mtp":
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["DeepSeekMTPModel"]
+            })
+        return hf_config
+
     @staticmethod
     def maybe_create_spec_config(
         target_model_config: ModelConfig,
@@ -1826,9 +1842,15 @@ def maybe_create_spec_config(
 
         if speculative_model is None:
             if num_speculative_tokens is not None:
-                raise ValueError("num_speculative_tokens was provided without "
+                if target_model_config.hf_text_config.model_type \
+                        == "deepseek_v3":
+                    # use the draft model from the same model:
+                    speculative_model = target_model_config.model
+                else:
+                    raise ValueError("num_speculative_tokens was provided without "
                                  "speculative_model.")
-            return None
+            else:
+                return None
 
         if (speculative_disable_by_batch_size is not None
                 and speculative_disable_by_batch_size < 2):
@@ -1882,6 +1904,7 @@ def maybe_create_spec_config(
                 max_seq_len_to_capture=target_model_config.
                 max_seq_len_to_capture,
                 max_logprobs=target_model_config.max_logprobs,
+                hf_overrides=SpeculativeConfig.hf_config_override,
             )
 
             draft_hf_config = draft_model_config.hf_config
@@ -2003,8 +2026,9 @@ def _verify_and_get_draft_model_tensor_parallel_size(
             speculative_draft_tensor_parallel_size = 1
             if target_parallel_config.tensor_parallel_size > 1:
                 logger.warning(
-                    "MLPSpeculator cannot currently be run with tp>1; "
-                    "setting speculative_draft_tensor_parallel_size=1")
+                    "%s cannot currently be run with tp>1; "
+                    "setting speculative_draft_tensor_parallel_size=1",
+                    draft_hf_config.model_type)
         else:
             speculative_draft_tensor_parallel_size = \
                 target_parallel_config.tensor_parallel_size
@@ -2039,6 +2063,8 @@ def create_draft_parallel_config(
             ray_workers_use_nsight=target_parallel_config.
             ray_workers_use_nsight,
             placement_group=target_parallel_config.placement_group,
+            enable_expert_parallel=target_parallel_config.
+            enable_expert_parallel,
         )
 
         return draft_parallel_config
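The net effect of these config changes is that a deepseek_v3 target can fall back to its own MTP module as the draft model when only num_speculative_tokens is given. A rough usage sketch, assuming the standard vLLM LLM entry point forwards these engine arguments; the checkpoint path and parallel size are placeholders, not values from this commit:

from vllm import LLM, SamplingParams

# speculative_model is intentionally omitted: maybe_create_spec_config() now
# reuses the target checkpoint and rewrites its hf_config to the deepseek_mtp
# draft architecture via SpeculativeConfig.hf_config_override.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",  # placeholder deepseek_v3-style checkpoint
    tensor_parallel_size=8,
    num_speculative_tokens=1,         # DeepSeek-V3/R1 ships a single MTP layer
)
outputs = llm.generate(["An example prompt"], SamplingParams(max_tokens=64))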

vllm/model_executor/layers/layernorm.py

Lines changed: 3 additions & 0 deletions
@@ -106,6 +106,9 @@ def forward_hpu(
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         from vllm_hpu_extension.kernels import rms_norm
+        if x.dim() < 3:
+            # fix an known bug before synapse 1.21 release
+            HPUFusedRMSNorm = None
         HPUFusedRMSNorm = rms_norm()
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
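The surrounding code chooses between the fused HPU kernel and the reference path at run time, and this hunk adds a dimensionality guard for inputs with fewer than three dimensions. A standalone sketch of that dispatch pattern; the function names and the generic fused_kernel callable are illustrative, only rms_norm and forward_native appear in the real file:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm, standing in for RMSNorm.forward_native.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def rms_norm_hpu_style(x: torch.Tensor, weight: torch.Tensor,
                       fused_kernel=None, eps: float = 1e-6) -> torch.Tensor:
    # Use the fused kernel only when it is available and the input is at
    # least 3-D (the same guard the hunk introduces); otherwise fall back.
    if fused_kernel is None or x.dim() < 3:
        return rms_norm_reference(x, weight, eps)
    return fused_kernel(x, weight, eps)

x = torch.randn(4, 128)  # 2-D input takes the reference path
w = torch.ones(128)
y = rms_norm_hpu_style(x, w)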

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 14 additions & 0 deletions
@@ -4,6 +4,7 @@
 from importlib.util import find_spec
 from typing import Dict, Optional, Tuple
 
+import os
 import torch
 import torch.jit
 
@@ -59,6 +60,10 @@ def __init__(self,
         else:
             logger.info("Use pytorch for rejection sampling.")
 
+        if os.environ.get('VLLM_MTP_PRINT_ACCPET_RATE', '1') != '0':
+            self.total_true = 0
+            self.total_false = 0
+
     def forward(
         self,
         target_with_bonus_probs: torch.Tensor,
@@ -298,6 +303,15 @@ def _get_accepted(
             torch.full((1, ), 1, device=target_probs.device))
         accepted = uniform_rand < capped_ratio
 
+        if os.environ.get('VLLM_MTP_PRINT_ACCPET_RATE', '1') != '0':
+            current_true = accepted.sum().item()
+            current_false = accepted.numel() - current_true
+            self.total_true += current_true
+            self.total_false += current_false
+            total = self.total_true + self.total_false
+            ratio_true = self.total_true / total if total != 0 else 0.0
+            print(f"Accepted ratio: {ratio_true:.2%} ({self.total_true}/{total})")
+
         return accepted
 
     def _get_recovered_probs(
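The counters above accumulate a running acceptance rate across sampler calls, gated by the VLLM_MTP_PRINT_ACCPET_RATE environment variable and enabled by default. An equivalent standalone sketch of the bookkeeping; the class name is illustrative, the real counters live on RejectionSampler:

import os
import torch

class AcceptRateTracker:
    """Illustrative re-implementation of the counters added above."""

    def __init__(self) -> None:
        self.enabled = os.environ.get('VLLM_MTP_PRINT_ACCPET_RATE', '1') != '0'
        self.total_true = 0
        self.total_false = 0

    def update(self, accepted: torch.Tensor) -> None:
        if not self.enabled:
            return
        current_true = int(accepted.sum().item())
        self.total_true += current_true
        self.total_false += accepted.numel() - current_true
        total = self.total_true + self.total_false
        ratio = self.total_true / total if total else 0.0
        print(f"Accepted ratio: {ratio:.2%} ({self.total_true}/{total})")

# Example: 3 of 4 draft tokens accepted in this step.
tracker = AcceptRateTracker()
tracker.update(torch.tensor([True, True, True, False]))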

vllm/model_executor/models/deepseek_v2.py

Lines changed: 20 additions & 30 deletions
@@ -297,8 +297,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         if is_hpu:
             # need reshape from tensor(x0, y0) to tensor(x1) for hpu
@@ -353,7 +351,7 @@ def forward(
             q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1])
             k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1])
             v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1])
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         if is_hpu:
             # need restore from tensor(x0, y0, z0) to tensor(x1, y1) for hpu
             attn_output = attn_output.reshape(
@@ -500,8 +498,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
@@ -511,8 +507,7 @@ def forward(
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
-                             attn_metadata)
+        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)
 
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -581,8 +576,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -595,8 +588,6 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -657,8 +648,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -673,12 +662,8 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            kvcaches = None if kv_caches is None else kv_caches[i - self.start_layer]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kvcaches,
-                                            attn_metadata, residual)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -715,13 +700,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
@@ -778,13 +760,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            # TODO(simon): support nextn predict layers
-            if hasattr(self.config, "num_nextn_predict_layers"
-                       ) and self.config.num_nextn_predict_layers > 0:
-                assert self.config.num_nextn_predict_layers == 1
-                layer_idx = self.config.num_hidden_layers
-                if name.startswith(f"model.layers.{layer_idx}"):
-                    continue
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
 
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -860,3 +838,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
+
+
+def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
+                                        weight_name: str) -> Optional[int]:
+    if hasattr(config,
+               "num_nextn_predict_layers") and (config.num_nextn_predict_layers
+                                                > 0):
+        layer_idx = config.num_hidden_layers
+        for i in range(config.num_nextn_predict_layers):
+            if weight_name.startswith(f"model.layers.{layer_idx+i}."):
+                return layer_idx + i
+    return None
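The new helper lets load_weights() detect MTP (nextn) weights by layer index and skip them when loading the main model. A quick illustration, assuming the helper above is in scope; the layer counts follow DeepSeek-V3 (61 decoder layers plus one MTP layer at index 61) and the weight names are illustrative:

from types import SimpleNamespace

cfg = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)

# Weights under model.layers.61.* belong to the MTP head and are skipped here.
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.61.eh_proj.weight") == 61
# Ordinary decoder weights still load as before.
assert get_spec_layer_idx_from_weight_name(
    cfg, "model.layers.60.mlp.gate.weight") is None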

vllm/model_executor/models/interfaces_base.py

Lines changed: 2 additions & 5 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
+from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union,
                     overload, runtime_checkable)
 
 import torch
@@ -11,7 +11,6 @@
 from vllm.utils import supports_kw
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import PoolerOutput
     from vllm.model_executor.layers.sampler import SamplerOutput
@@ -46,8 +45,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
     ) -> T_co:
         ...
 
@@ -62,7 +59,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
     if not callable(model_forward):
         return False
 
-    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
+    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))
 
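After this change the protocol check only requires the input_ids and positions keywords on forward(). A small sketch of a model stub that would now pass that keyword check; the class and the inline inspect-based check are illustrative, the real test uses vllm.utils.supports_kw:

import inspect
import torch

class ToyVllmModel:
    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # kv_caches / attn_metadata are no longer part of the signature;
        # attention layers read them from the forward context instead.
        return torch.zeros_like(input_ids)

vllm_kws = ("input_ids", "positions")
params = inspect.signature(ToyVllmModel.forward).parameters
assert all(kw in params for kw in vllm_kws)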
vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@
 
 _SPECULATIVE_DECODING_MODELS = {
     "EAGLEModel": ("eagle", "EAGLE"),
+    "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
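Each registry entry maps an architecture name (here the "DeepSeekMTPModel" string that hf_config_override writes into the draft hf_config) to a module and class under vllm/model_executor/models. A simplified sketch of how such a tuple resolves to a class, assuming this branch's deepseek_mtp module is installed; the real lookup in registry.py is lazy and error-checked:

import importlib

module_name, class_name = ("deepseek_mtp", "DeepSeekMTP")
mtp_cls = getattr(
    importlib.import_module(f"vllm.model_executor.models.{module_name}"),
    class_name)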

vllm/sequence.py

Lines changed: 2 additions & 0 deletions
@@ -1305,6 +1305,8 @@ class ExecuteModelRequest(
     previous_hidden_states: Optional[HiddenStates] = None
     # The number of forward steps to run.
     num_steps: int = 1
+    # The step index for spec model input.
+    spec_step_idx: Optional[int] = None
     # Finished request ids since last step.
     finished_requests_ids: List[str] = msgspec.field(default_factory=list)
     # The last sampled token ids for multi step decoding.
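Because spec_step_idx defaults to None, existing call sites that build ExecuteModelRequest keep working and only the MTP proposer needs to set it to pick which speculative step a forward pass targets. A tiny sketch of the same backward-compatible pattern on a toy msgspec struct; ToyRequest is illustrative, not the real class:

from typing import List, Optional
import msgspec

class ToyRequest(msgspec.Struct):
    # Existing required field.
    prompt_token_ids: List[int]
    # New optional field: old call sites can ignore it entirely.
    spec_step_idx: Optional[int] = None

old_style = ToyRequest(prompt_token_ids=[1, 2, 3])
mtp_step = ToyRequest(prompt_token_ids=[1, 2, 3], spec_step_idx=0)
assert old_style.spec_step_idx is None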
