Merged
11 changes: 5 additions & 6 deletions vllm/config.py
@@ -1977,13 +1977,12 @@ def maybe_create_spec_config(
             if num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
                 num_speculative_tokens = n_predict
-            elif num_speculative_tokens > n_predict:
-                # Verify provided value doesn't exceed the maximum
-                # supported by the draft model.
+            elif num_speculative_tokens > n_predict and \
+                    num_speculative_tokens % n_predict != 0:
+                # Ensure divisibility for MTP module reuse.
                 raise ValueError(
-                    "This speculative model supports a maximum of "
-                    f"num_speculative_tokens={n_predict}, but "
-                    f"{num_speculative_tokens=} was provided.")
+                    f"{num_speculative_tokens=} must be divisible by "
+                    f"{n_predict=}")
 
         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
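For illustration only (not part of the diff): the relaxed check allows any num_speculative_tokens that is a multiple of the draft model's n_predict, since the MTP modules are cycled (see the deepseek_mtp.py change below). A minimal standalone sketch, assuming n_predict comes from the draft model's config (e.g. 1 for a single-MTP-module DeepSeek checkpoint):

```python
def validate(num_speculative_tokens: int, n_predict: int) -> None:
    # Mirrors the new rule: k may exceed n_predict only if it is a multiple of it.
    if (num_speculative_tokens > n_predict
            and num_speculative_tokens % n_predict != 0):
        raise ValueError(f"{num_speculative_tokens=} must be divisible by "
                         f"{n_predict=}")

validate(4, 1)   # ok: the single MTP module is reused for 4 speculative steps
validate(4, 2)   # ok: each of the 2 modules is reused twice
try:
    validate(3, 2)
except ValueError as e:
    print(e)     # num_speculative_tokens=3 must be divisible by n_predict=2
```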
14 changes: 9 additions & 5 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -87,7 +87,7 @@ def forward(
                 hidden_states=hidden_states,
                 residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states
 
 
 class DeepSeekMultiTokenPredictor(nn.Module):
@@ -121,12 +121,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )
 
     def compute_logits(
@@ -135,9 +136,12 @@ def compute_logits(
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
 
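For illustration only (not part of the diff): the modulo maps each speculative step onto one of the available MTP layers, so a checkpoint with fewer MTP layers than num_speculative_tokens simply cycles through them. A sketch assuming a DeepSeek-V3-style layout (61 decoder layers followed by a single MTP layer):

```python
mtp_start_layer_idx = 61   # assumed: the MTP weights live at layer index 61
num_mtp_layers = 1         # assumed: one MTP module in the checkpoint

for spec_step_idx in range(4):
    current_step_idx = spec_step_idx % num_mtp_layers
    print(f"spec step {spec_step_idx} -> MTP layer "
          f"{mtp_start_layer_idx + current_step_idx}")
# Every step resolves to layer 61, so one module can serve k > 1 speculative steps.
```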
11 changes: 4 additions & 7 deletions vllm/spec_decode/draft_model_runner.py
@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """
 
     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
@@ -153,7 +147,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
            return False
 
         # TODO: Add support for LORA
@@ -307,6 +301,9 @@ def execute_model(
             )
             outputs.append(output)
 
+            if self.return_hidden_states and is_fallback:
+                output.hidden_states = hidden_states
+
             if model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
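For illustration only (not part of the diff): with the MLA backend removed from the allow-list, MTP drafts no longer take the fused advance_step path; they run one step at a time, and that fallback path is where the hidden states are now attached to the output. A standalone sketch of the gate, using hypothetical names:

```python
# Hypothetical standalone version of the backend gate: only FlashAttention
# keeps GPU multi-step; MLA-based drafts take the single-step fallback.
MULTI_STEP_ATTN_BACKENDS = ("FLASH_ATTN", )

def use_gpu_multi_step(attn_backend_name: str) -> bool:
    return attn_backend_name in MULTI_STEP_ATTN_BACKENDS

for name in ("FLASH_ATTN", "TRITON_MLA"):
    print(name, "->", "multi-step" if use_gpu_multi_step(name) else "fallback")
```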
17 changes: 17 additions & 0 deletions vllm/spec_decode/multi_step_worker.py
@@ -96,12 +96,16 @@ def sampler_output(
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)
 
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
@@ -115,6 +119,19 @@
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
 
+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
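For illustration only (not part of the diff): a toy version of the chaining that _maybe_update_previous_hidden_states enables. The types below are stand-ins rather than vLLM classes; the point is only that the hidden states produced at step i become the previous_hidden_states consumed at step i + 1.

```python
from typing import List, Optional

import torch

class ToyRequest:
    """Stand-in for ExecuteModelRequest; only carries previous_hidden_states."""

    def __init__(self, previous_hidden_states: Optional[torch.Tensor]) -> None:
        self.previous_hidden_states = previous_hidden_states

def toy_draft_step(req: ToyRequest) -> torch.Tensor:
    """Stand-in for one MTP forward pass that returns new hidden states."""
    prev = req.previous_hidden_states
    if prev is None:
        prev = torch.zeros(2, 8)  # [batch, hidden_size]
    return prev + 1.0

def run_draft_steps(req: ToyRequest, sample_len: int) -> List[torch.Tensor]:
    outputs = []
    for _ in range(sample_len):
        hidden = toy_draft_step(req)
        if req.previous_hidden_states is not None:
            # In-place update of the request, as in the PR.
            req.previous_hidden_states = hidden
        outputs.append(hidden)
    return outputs

outs = run_draft_steps(ToyRequest(torch.zeros(2, 8)), sample_len=3)
print(outs[-1][0, 0].item())  # 3.0 -> each step consumed the previous step's states
```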
6 changes: 3 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
@@ -184,8 +184,7 @@ def create_worker(
         elif draft_model_config.hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
-            if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                "deepseek_mtp":
+            if draft_tp == 1:
Collaborator:
what's the concern for using TP1DraftModelRunner for deepseek_mtp where tp > 1 here?

Collaborator Author:
@luccafong
I don't think that TP1DraftModelRunner should be used at all for MTP. Since the advance-step is not compatible with the MLA backend, it isn't going to work for k > 1 without some significant changes. In this PR I maintain compatibility with TP1DraftModelRunner so that it can be used in the future when the advance-step is implemented, but even then I believe it will only be TP=1 as the broadcast limitations of TP1DraftModelRunner haven't been added either.

Using the plain ModelRunner for this branch makes the most sense to me, as we are using it as intended for single-step execution.

Collaborator:
I was planning to make TP1DraftModelRunner compatible with both cases, so the name is probably confusing. Actually, switching the model runner will make it harder to support real multi-step (k > 1) MTP later. It's good for this PR's use case, though.

Collaborator Author:
That sounds challenging. Best of luck!

Contributor:
@luccafong QQ: does TP1DraftModelRunner work when the draft model has TP=8? If not, then this PR also has the benefit of greatly relieving the memory pressure, I believe, since otherwise the full MTP module (~16GB of weights) will all reside on the first device, which limits the max context length that can be supported?

Collaborator (luccafong, Feb 25, 2025):
> @luccafong QQ: does TP1DraftModelRunner work when the draft model has TP=8? […]

The previous implementation supports TP=8 in the draft model runner for MTP in the case of k=1, and we have attached the benchmark for TP=8 in #12755.

Contributor:
@luccafong ah this makes total sense now. thanks for the clarification!

Collaborator Author:
@pyc96 Yes, there is some special handling there and in a few other places. Right now multi-step is hard-coded in several ways to be specific to the Flash Attention backend. The advance_step operation handles some attention-specific state processing, and so its behaviour is specific to the attention backend being used.

I expect that much of this logic could be re-used for MLA, but this will have to be its own contribution. In the past I have tried to forcibly enable this code path for MLA, and there were several mismatches in state that would need to be handled specifically. I don't think that adding this compatibility is a trivial change. If you disable MLA, then the multi-step code is usable.

Contributor:
@benchislett Thank you for your reply! On top of your PR, I changed

    assert isinstance(attn_metadata, FlashAttentionMetadata)

to

    assert isinstance(attn_metadata, (FlashAttentionMetadata, MLACommonMetadata))

and it just worked fine for the drafter TP=1 case.

Actually, without this change, for the TP=1 case I believe the current code will throw this assertion error during request serving. Is that intended?

Collaborator Author:
@pyc96 I can confirm that it does seem to work with multi-step now. Even vllm/worker/multi_step_model_runner.py seems to be successful with the MLA backend. However, changing the allowed backends list is still required there as well, which leads me to believe this support wasn't specifically added but rather that the blocking restrictions were independently eliminated.

For this PR, I would prefer to merge sooner and have follow-up work (cc @luccafong) add the multi-step functionality for DeepSeek, with thorough testing and sign-off.

For now, I've pushed a commit which removes the MLA backend from supports_gpu_multi_step so that the fallback path must be used. This will avoid the assertion failure you identified. If the community is confident that multi-step is safe with MLA, I am happy to turn the support back on.

                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
@@ -203,7 +202,8 @@ def create_worker(
 
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict
 
         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
12 changes: 11 additions & 1 deletion vllm/worker/model_runner.py
@@ -1685,11 +1685,22 @@ def execute_model(
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
 
@@ -1716,7 +1727,6 @@ def execute_model(
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states
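For illustration only (not part of the diff): CUDA graphs are captured for fixed batch sizes, so before replaying a captured graph the previous_hidden_states tensor has to be padded up to the captured graph_batch_size. A minimal standalone sketch with hypothetical sizes:

```python
import torch

graph_batch_size, hidden_size = 8, 16                  # assumed sizes for the example
previous_hidden_states = torch.randn(3, hidden_size)   # 3 live decode sequences

padded = torch.cat([
    previous_hidden_states,
    torch.empty(graph_batch_size - previous_hidden_states.shape[0],
                hidden_size,
                dtype=previous_hidden_states.dtype,
                device=previous_hidden_states.device),
])
print(padded.shape)  # torch.Size([8, 16]); the padding rows are never sampled from
```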