         # vllm_flash_attn is not installed, try the ROCm FA metadata
         from vllm.attention.backends.rocm_flash_attn import (
             ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.triton_mla import TritonMLAMetadata
+    except (ModuleNotFoundError, ImportError):
+        TritonMLAMetadata = FlashAttentionMetadata
+
 except (ModuleNotFoundError, ImportError) as err:
     raise RuntimeError(
         "Draft model speculative decoding currently only supports "
@@ -57,7 +62,7 @@ def __init__(self, model_runner: ModelRunnerBase):
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
         super().__init__(model_runner)
-
+        self.mtp = False
         self.indices_of_seq_with_bonus_tokens = None
 
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
@@ -92,7 +97,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
 
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata,
+                          (FlashAttentionMetadata, TritonMLAMetadata))
 
         attn_metadata.advance_step(model_input, sampled_token_ids,
                                    self.block_size, num_seqs, num_queries)
@@ -193,6 +199,7 @@ def execute_model(
         # iteration invokes this function only once
         # (Look at multi-step-worker code)
         is_fallback = num_steps == 1
+        self.mtp = self.model.config.model_type == "deepseek_mtp"
         if not is_fallback:
             # Since we do not broadcast data inside execute_model anymore,
             # we need to figure out the best way to support TP > 1 in this
@@ -269,6 +276,9 @@ def execute_model(
             hidden_states = previous_hidden_states
 
         outputs: List[SamplerOutput] = []
+        input_tokens = model_input.input_tokens
+        input_positions = model_input.input_positions
+        attn_metadata = model_input.attn_metadata
         for step in range(num_steps):
             multi_modal_kwargs = model_input.multi_modal_kwargs or {}
 
@@ -277,37 +287,64 @@ def execute_model(
 
             compute_logits_kwargs = {}
             # Run model
-            if hasattr(self.model.config, "num_nextn_predict_layers"):
+            spec_step_idx = kwargs.get("spec_step_idx", 0)
+            if self.model_config.requires_multi_step_decode:
                 # for DeepSeek MTP only to use the corresponding layer for
                 # each step
                 spec_step_idx = kwargs.get("spec_step_idx", step)
-                model_execute_kwargs["spec_step_idx"] = spec_step_idx
-                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
-            with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                if spec_step_idx >= 0:
+                    model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                    compute_logits_kwargs["spec_step_idx"] = spec_step_idx
+
+                graph_batch_size = model_input.input_tokens.shape[0]
+                graph_idx = self.parallel_config.pipeline_parallel_size * spec_step_idx + model_input.virtual_engine
+                model_executable = self.graph_runners[graph_idx][graph_batch_size]
+            elif not use_cuda_graph:
+                # for single step prefill
+                with set_forward_context(attn_metadata, self.vllm_config):
+                    return model_executable.generate_proposals(
+                        input_ids=input_tokens,
+                        positions=input_positions,
+                        kv_caches=kv_caches,
+                        attn_metadata=attn_metadata,
+                        sampling_metadata=model_input.sampling_metadata,
+                        **model_execute_kwargs,
+                    )
+            # model_execute_kwargs["spec_step_idx"] = spec_step_idx
+            with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = model_executable(
-                    input_ids=model_input.input_tokens,
-                    positions=model_input.input_positions,
+                    input_ids=input_tokens,
+                    positions=input_positions,
+                    kv_caches=kv_caches,
+                    attn_metadata=attn_metadata,
                     intermediate_tensors=intermediate_tensors,
                     **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                  device=self.device),
                     **model_execute_kwargs,
                 )
 
             # Compute the logits.
-            logits = self.model.compute_logits(hidden_states,
-                                               model_input.sampling_metadata,
-                                               **compute_logits_kwargs)
+            logits = self.model.compute_logits(
+                hidden_states,  # do not sample for the previous tokens
+                model_input.sampling_metadata,
+                **compute_logits_kwargs)
             if not self.is_driver_worker:
                 return []
             # Sample the next token.
             output = self.model.sample(
                 logits=logits,
                 sampling_metadata=model_input.sampling_metadata,
             )
+            # TODO: do sampling/compute logits for the last token only
+            if self.mtp:
+                # return last token only for each step for MTP
+                output = self.model.get_last_sample_output(
+                    output, attn_metadata)
+                input_tokens = self.model.get_next_layer_input(
+                    input_tokens, attn_metadata, output)
             outputs.append(output)
 
-            if model_input.attn_metadata.num_prefills == 0 \
+            if not self.mtp and model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
                 # output.sampled_token_ids should be of shape (num_seqs, 1)
@@ -327,7 +364,7 @@ def execute_model(
                         count += 1
 
             # Prepare inputs for the next step
-            if step != num_steps - 1 and not self.mtp:
+            if step != num_steps - 1 and not self.mtp:
                 model_input = self._gpu_advance_step(model_input, outputs[-1])
 
         return outputs
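
Note on the CUDA-graph lookup added in this change: each speculative step and virtual engine gets its own bucket of captured graphs, indexed as graph_idx = pipeline_parallel_size * spec_step_idx + virtual_engine, and the padded batch size then selects the captured graph inside that bucket. Below is a minimal, hypothetical sketch of that indexing scheme; build_runner_table and DummyRunner are illustrative stand-ins, not vLLM APIs.

# Hypothetical sketch (not vLLM code) of the
# self.graph_runners[graph_idx][graph_batch_size] lookup used above.
from typing import Dict, List


class DummyRunner:
    # Stand-in for a captured CUDA-graph runner.
    def __init__(self, tag: str):
        self.tag = tag


def build_runner_table(pipeline_parallel_size: int, num_spec_steps: int,
                       batch_sizes: List[int]) -> List[Dict[int, DummyRunner]]:
    # Buckets are appended in the order the lookup expects:
    # graph_idx = pipeline_parallel_size * spec_step_idx + virtual_engine
    table: List[Dict[int, DummyRunner]] = []
    for spec_step_idx in range(num_spec_steps):
        for virtual_engine in range(pipeline_parallel_size):
            table.append({
                bs: DummyRunner(
                    f"step={spec_step_idx} ve={virtual_engine} bs={bs}")
                for bs in batch_sizes
            })
    return table


graph_runners = build_runner_table(pipeline_parallel_size=2,
                                   num_spec_steps=3,
                                   batch_sizes=[1, 2, 4, 8])
spec_step_idx, virtual_engine, graph_batch_size = 1, 0, 4
graph_idx = 2 * spec_step_idx + virtual_engine
print(graph_runners[graph_idx][graph_batch_size].tag)  # step=1 ve=0 bs=4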