Commit 9804145

[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict (#13626)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
1 parent 2e94b9c commit 9804145

6 files changed: +49 -22 lines changed

vllm/config.py

Lines changed: 5 additions & 6 deletions
@@ -1978,13 +1978,12 @@ def maybe_create_spec_config(
             if num_speculative_tokens is None:
                 # Default to max value defined in draft model config.
                 num_speculative_tokens = n_predict
-            elif num_speculative_tokens > n_predict:
-                # Verify provided value doesn't exceed the maximum
-                # supported by the draft model.
+            elif num_speculative_tokens > n_predict and \
+                    num_speculative_tokens % n_predict != 0:
+                # Ensure divisibility for MTP module reuse.
                 raise ValueError(
-                    "This speculative model supports a maximum of "
-                    f"num_speculative_tokens={n_predict}, but "
-                    f"{num_speculative_tokens=} was provided.")
+                    f"{num_speculative_tokens=} must be divisible by "
+                    f"{n_predict=}")
 
         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
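
The new check relaxes the old hard cap: num_speculative_tokens may now exceed n_predict as long as it is a whole multiple of it, so the k proposed tokens can be produced by cycling through the draft model's n_predict MTP modules. A minimal standalone sketch of the same rule (the helper name is illustrative, not vLLM's API):

def validate_num_speculative_tokens(num_speculative_tokens: int,
                                    n_predict: int) -> None:
    # Mirror the new config check: values above n_predict are allowed only
    # when they map onto whole passes over the MTP modules.
    if num_speculative_tokens > n_predict and \
            num_speculative_tokens % n_predict != 0:
        raise ValueError(f"{num_speculative_tokens=} must be divisible by "
                         f"{n_predict=}")

validate_num_speculative_tokens(4, 1)    # accepted: one MTP module reused 4 times
# validate_num_speculative_tokens(3, 2)  # would raise: 3 is not a multiple of 2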

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 9 additions & 5 deletions
@@ -87,7 +87,7 @@ def forward(
             hidden_states=hidden_states,
             residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states
 
 
 class DeepSeekMultiTokenPredictor(nn.Module):
@@ -121,12 +121,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )
 
     def compute_logits(
@@ -135,9 +136,12 @@ def compute_logits(
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
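Two things change in the predictor: each MTP layer's forward now returns raw hidden states (the shared head is applied later, in compute_logits), and the step index is reduced modulo num_mtp_layers so that steps beyond n_predict wrap back onto the available MTP modules. A small sketch of that wrap-around indexing, with illustrative values for the layer count and start index:

num_mtp_layers = 2        # illustrative: the draft model exposes n_predict = 2 MTP modules
mtp_start_layer_idx = 61  # illustrative index of the first MTP layer

for spec_step_idx in range(4):  # k = 4 speculative steps
    current_step_idx = spec_step_idx % num_mtp_layers
    print(spec_step_idx, "->", str(mtp_start_layer_idx + current_step_idx))
# Prints 0 -> 61, 1 -> 62, 2 -> 61, 3 -> 62: the two modules are reused in turn.
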
vllm/spec_decode/draft_model_runner.py

Lines changed: 4 additions & 7 deletions
@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """
 
     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
@@ -153,7 +147,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
             return False
 
         # TODO: Add support for LORA
@@ -307,6 +301,9 @@ def execute_model(
                 )
                 outputs.append(output)
 
+                if self.return_hidden_states and is_fallback:
+                    output.hidden_states = hidden_states
+
                 if model_input.attn_metadata.num_prefills == 0 \
                         and self.indices_of_seq_with_bonus_tokens is not None:
                     assert output.sampled_token_ids is not None
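
TP1DraftModelRunner no longer rejects return_hidden_states; on the fallback (non-multi-step) path each step's output now carries its hidden states so the proposer loop can chain them. A tiny sketch of that conditional attachment, using a stand-in output object rather than vLLM's SamplerOutput:

from types import SimpleNamespace

import torch

return_hidden_states = True         # set by MultiStepWorker when MTP needs hidden states
is_fallback = True                  # the multi-step CUDA-graph path was not taken
hidden_states = torch.zeros(2, 16)  # this step's last-layer hidden states (made-up shape)

output = SimpleNamespace(sampled_token_ids=None, hidden_states=None)
if return_hidden_states and is_fallback:
    output.hidden_states = hidden_states  # picked up by the next speculative step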

vllm/spec_decode/multi_step_worker.py

Lines changed: 17 additions & 0 deletions
@@ -96,12 +96,16 @@ def sampler_output(
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)
 
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
@@ -115,6 +119,19 @@ def sampler_output(
                 model_outputs, indices_of_seq_with_bonus_tokens)
             return filtered_model_outputs, True
 
+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
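
The proposer loop now threads hidden states through the speculative steps: after every step, the hidden states from the model output replace previous_hidden_states on the expanded request, which is what lets the MTP modules be chained for k steps. A minimal sketch of that chaining loop with stand-in types (RequestStub and StepOutput are hypothetical; vLLM uses ExecuteModelRequest, SamplerOutput and HiddenStates):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class StepOutput:                 # stand-in for SamplerOutput
    hidden_states: torch.Tensor

@dataclass
class RequestStub:                # stand-in for ExecuteModelRequest
    previous_hidden_states: Optional[torch.Tensor]

def run_one_step(request: RequestStub) -> StepOutput:
    # Stand-in for worker.execute_model(): derive this step's hidden states.
    return StepOutput(hidden_states=request.previous_hidden_states + 1.0)

request = RequestStub(previous_hidden_states=torch.zeros(2, 4))
for _ in range(3):  # sample_len speculative steps
    output = run_one_step(request)
    if request.previous_hidden_states is not None:
        # Mirrors _maybe_update_previous_hidden_states: refresh the request
        # in place so the next step consumes the newest hidden states.
        request.previous_hidden_states = output.hidden_states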

vllm/spec_decode/spec_decode_worker.py

Lines changed: 3 additions & 3 deletions
@@ -184,8 +184,7 @@ def create_worker(
         elif draft_model_config.hf_config.model_type == "medusa":
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
-            if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                "deepseek_mtp":
+            if draft_tp == 1:
                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
@@ -203,7 +202,8 @@ def create_worker(
 
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict
 
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)

vllm/worker/model_runner.py

Lines changed: 11 additions & 1 deletion
@@ -1685,11 +1685,22 @@ def execute_model(
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
         virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
        if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
@@ -1716,7 +1727,6 @@ def execute_model(
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states
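
On the CUDA-graph decode path the captured graph expects exactly graph_batch_size rows, while previous_hidden_states arrives sized to the real batch, so the tensor is padded with uninitialized rows before being handed to the graph runner. The same padding step in isolation, with illustrative sizes (the padded rows correspond to padding slots the graph runner ignores):

import torch

previous_hidden_states = torch.randn(3, 16)  # [actual_batch, hidden_size], made-up sizes
graph_batch_size = 8                         # batch size the CUDA graph was captured with

previous_hidden_states = torch.cat([
    previous_hidden_states,
    torch.empty([graph_batch_size - previous_hidden_states.shape[0],
                 *previous_hidden_states.shape[1:]],
                dtype=previous_hidden_states.dtype,
                device=previous_hidden_states.device),
])
assert previous_hidden_states.shape[0] == graph_batch_size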
