
Commit 36fc1ec

support k > 1
Signed-off-by: Lu Fang <fanglu@fb.com>
1 parent 18e5059 commit 36fc1ec

9 files changed: +296 additions, −80 deletions


vllm/config.py

Lines changed: 15 additions & 1 deletion
@@ -1010,6 +1010,18 @@ def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
         return is_encoder_decoder(self.hf_config)
 
+    @property
+    def requires_multi_step_decode(self) -> bool:
+        return getattr(self.hf_config, "model_type", "") == "deepseek_mtp" and \
+            getattr(self.hf_config, "num_nextn_predict_layers", 0) > 1
+
+    @property
+    def num_decode_modules(self) -> int:
+        if getattr(self.hf_config, "model_type", "") == "deepseek_mtp":
+            return getattr(self.hf_config, "num_nextn_predict_layers", 0)
+        else:
+            return 1
+
     @property
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
@@ -3468,7 +3480,8 @@ def _set_cudagraph_sizes(self):
             # which then becomes the max_batchsize_to_capture
             larger_sizes = [
                 x for x in possible_sizes
-                if x >= self.scheduler_config.max_num_seqs
+                if x >= self.scheduler_config.max_num_seqs *
+                self.model_config.num_decode_modules
            ]
            if larger_sizes:
                max_batchsize_to_capture = larger_sizes[0]
@@ -3481,6 +3494,7 @@ def _set_cudagraph_sizes(self):
                size for size in possible_sizes
                if size <= max_batchsize_to_capture
            ]
+            # print(f"{batch_size_capture_list=}")
        else:
            batch_size_capture_list = []
        if self.model_config is not None and \
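
A minimal standalone sketch (not vLLM code; SimpleNamespace stands in for the real hf_config object) of how the two new ModelConfig properties resolve, and why the CUDA-graph capture threshold now scales with num_decode_modules:

from types import SimpleNamespace

def requires_multi_step_decode(hf_config) -> bool:
    # Mirrors the new property: only DeepSeek MTP with more than one
    # next-n prediction layer needs the multi-step decode path.
    return getattr(hf_config, "model_type", "") == "deepseek_mtp" and \
        getattr(hf_config, "num_nextn_predict_layers", 0) > 1

def num_decode_modules(hf_config) -> int:
    # Mirrors the new property: one decode module per MTP layer,
    # otherwise a single module.
    if getattr(hf_config, "model_type", "") == "deepseek_mtp":
        return getattr(hf_config, "num_nextn_predict_layers", 0)
    return 1

mtp_cfg = SimpleNamespace(model_type="deepseek_mtp", num_nextn_predict_layers=2)
dense_cfg = SimpleNamespace(model_type="llama")

print(requires_multi_step_decode(mtp_cfg), num_decode_modules(mtp_cfg))      # True 2
print(requires_multi_step_decode(dense_cfg), num_decode_modules(dense_cfg))  # False 1

# In _set_cudagraph_sizes above, max_num_seqs is multiplied by
# num_decode_modules, so with max_num_seqs=256 and 2 MTP layers the
# capture threshold becomes 512.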

vllm/engine/output_processor/multi_step.py

Lines changed: 2 additions & 0 deletions
@@ -185,6 +185,8 @@ def _process_seq_outputs(self, seq: Sequence,
         is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
+        # TODO: add an attribute here for reset, can be set at output processor
+        seq.data.reset_new_appended_tokens()
         for output_token_id, output_logprob in zip(output_token_ids,
                                                    output_logprobs):
             seq.append_token_id(
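
A standalone toy (the class below is illustrative, not vLLM's SequenceData) showing the effect of the reset added above: clearing the new-appended-tokens buffer before the append loop leaves only the tokens appended in this call, which appears to be the point once a single step can append k > 1 tokens:

class SeqDataSketch:
    # Hypothetical stand-in for the two SequenceData members involved.
    def __init__(self):
        self._output_token_ids = []
        self._new_appended_tokens = []

    def reset_new_appended_tokens(self):
        # Same body as the new SequenceData.reset_new_appended_tokens.
        self._new_appended_tokens = []

    def append_token_id(self, token_id):
        self._output_token_ids.append(token_id)
        self._new_appended_tokens.append(token_id)

data = SeqDataSketch()
data.append_token_id(5)            # previous step
data.reset_new_appended_tokens()   # start of the current step
data.append_token_id(7)
data.append_token_id(9)
print(data._new_appended_tokens)   # [7, 9] -- only this step's tokens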

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 53 additions & 0 deletions
@@ -175,6 +175,47 @@ def compute_logits(
         return self.model.compute_logits(hidden_states, sampling_metadata,
                                          spec_step_idx)
 
+    def generate_proposals(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        previous_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> List[SamplerOutput]:
+        hidden_states = previous_hidden_states
+        cur_input_ids = input_ids
+        outputs = []
+        for i in range(self.model.num_mtp_layers):
+            hidden_states = self.forward(cur_input_ids,
+                                         positions,
+                                         kv_caches,
+                                         attn_metadata,
+                                         hidden_states,
+                                         spec_step_idx=i)
+            logits = self.compute_logits(hidden_states=hidden_states,
+                                         sampling_metadata=sampling_metadata,
+                                         spec_step_idx=i)
+            output = self.sample(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+            outputs.append(output)
+            cur_input_ids = self.get_next_layer_input(input_ids, attn_metadata,
+                                                      output)
+        return outputs
+
+    def get_next_layer_input(
+            self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata,
+            outputs: SamplerOutput) -> Tuple[torch.Tensor, SamplerOutput]:
+        assert outputs.sampled_token_ids is not None
+        assert attn_metadata.query_start_loc is not None
+        input_ids = input_ids.roll(shifts=-1, dims=0)
+        query_end_loc = attn_metadata.query_start_loc[1:] - 1
+        input_ids[query_end_loc] = outputs.sampled_token_ids[:, 0]
+        return input_ids
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -183,6 +224,18 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def get_last_sample_output(
+        self,
+        output: SamplerOutput,
+        attn_metadata: AttentionMetadata,
+    ) -> SamplerOutput:
+        query_end_loc = attn_metadata.query_start_loc[1:] - 1
+        output.sampled_token_ids = output.sampled_token_ids[query_end_loc]
+        if output.sampled_token_probs is not None:
+            output.sampled_token_probs = output.sampled_token_probs[
+                query_end_loc]
+        return output
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
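
A toy, standalone illustration (made-up sequence lengths, not vLLM code) of what get_next_layer_input does between MTP layers: shift the packed token ids left by one and write each sequence's freshly sampled token into that sequence's last query position (query_start_loc[1:] - 1):

import torch

# Two packed sequences: A = [11, 12, 13], B = [21, 22], so query_start_loc = [0, 3, 5].
input_ids = torch.tensor([11, 12, 13, 21, 22])
query_start_loc = torch.tensor([0, 3, 5])
sampled_token_ids = torch.tensor([[99], [88]])        # one new token per sequence

next_input = input_ids.roll(shifts=-1, dims=0)        # [12, 13, 21, 22, 11]
query_end_loc = query_start_loc[1:] - 1               # last position of each sequence: [2, 4]
next_input[query_end_loc] = sampled_token_ids[:, 0]   # overwrite with sampled tokens
print(next_input)                                     # tensor([12, 13, 99, 22, 88])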

vllm/sequence.py

Lines changed: 30 additions & 7 deletions
@@ -365,6 +365,9 @@ def get_delta_and_reset(self) -> SequenceDataDelta:
         self._new_appended_tokens = []
         return delta
 
+    def reset_new_appended_tokens(self) -> None:
+        self._new_appended_tokens = []
+
     def apply_delta(self, delta: SequenceDataDelta):
         self._num_computed_tokens = delta.new_num_computed_tokens
         self._cumulative_logprob = delta.new_cumulative_logprob
@@ -1212,12 +1215,13 @@ class HiddenStates(msgspec.Struct, array_like=True,
     # last proposed token is accepted (i.e., in case of bonus tokens). For the
     # case of no bonus tokens, these are ignored.
     second_last_token_hidden_states: Optional[torch.Tensor] = None
-
+    # for varseq
+    hidden_states_seq_indices: Optional[torch.Tensor] = None
     _seq_ids: List[int] = msgspec.field(default_factory=list)
 
     def __post_init__(self):
         if self.seq_group_metadata_list is not None:
-            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            # TODO: add assertion for the group metadata list with var seqs
             self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
 
     @property
@@ -1231,8 +1235,20 @@ def update(self,
         """Update hidden states from target model invocation. Only used for
         decode steps"""
         assert len(seq_group_metadata_list) == len(hidden_states)
-        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        last_seq_indice = len(self._seq_ids)
+        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        self._seq_ids.extend(new_seq_ids)
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+        if self.hidden_states_seq_indices is not None:
+            updated_indices = list(range(last_seq_indice, len(self._seq_ids)))
+            # assume new updated are hidden states from
+            # prefill which is always length of 1
+            new_seq_indices = torch.tensor(
+                updated_indices, device=self.hidden_states_seq_indices.device)
+            self.hidden_states_seq_indices = torch.concat([
+                self.hidden_states_seq_indices,
+                new_seq_indices,
+            ])
 
         if self.second_last_token_hidden_states is not None:
             # Adding dummy hidden_states to this to maintain same shape
@@ -1255,10 +1271,17 @@ def prune(self,
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
-            self.hidden_states = self.hidden_states[index]
-            if self.second_last_token_hidden_states is not None:
-                self.second_last_token_hidden_states = self\
-                    .second_last_token_hidden_states[index]
+            if self.hidden_states_seq_indices is not None:
+                target_indices_tensor = torch.tensor(
+                    index, device=self.hidden_states_seq_indices.device)
+                index = (self.hidden_states_seq_indices[..., None] ==
+                         target_indices_tensor).any(dim=-1)
+                self.hidden_states = self.hidden_states[index]
+            else:
+                self.hidden_states = self.hidden_states[index]
+            if self.second_last_token_hidden_states is not None:
+                self.second_last_token_hidden_states = self\
+                    .second_last_token_hidden_states[index]
             self._seq_ids = seq_ids
 
     def expand_with_bonus_tokens(
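
A standalone sketch (shapes assumed, not vLLM code) of the new prune path: when hidden_states_seq_indices is set, each hidden-state row is tagged with the sequence slot it belongs to, and pruning keeps every row whose slot survives instead of selecting exactly one row per sequence:

import torch

hidden_states = torch.arange(10.).reshape(5, 2)            # 5 hidden-state rows
hidden_states_seq_indices = torch.tensor([0, 0, 1, 2, 2])  # row -> sequence slot
surviving_slots = torch.tensor([0, 2])                     # slots kept after prune

# Same boolean-mask construction as in the updated prune():
mask = (hidden_states_seq_indices[..., None] == surviving_slots).any(dim=-1)
print(hidden_states[mask])   # only the rows belonging to slots 0 and 2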

vllm/spec_decode/draft_model_runner.py

Lines changed: 51 additions & 14 deletions
@@ -14,6 +14,11 @@
         # vllm_flash_attn is not installed, try the ROCm FA metadata
         from vllm.attention.backends.rocm_flash_attn import (
             ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.triton_mla import TritonMLAMetadata
+    except (ModuleNotFoundError, ImportError):
+        TritonMLAMetadata = FlashAttentionMetadata
+
 except (ModuleNotFoundError, ImportError) as err:
     raise RuntimeError(
         "Draft model speculative decoding currently only supports "
@@ -57,7 +62,7 @@ def __init__(self, model_runner: ModelRunnerBase):
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
         super().__init__(model_runner)
-
+        self.mtp = False
         self.indices_of_seq_with_bonus_tokens = None
 
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
@@ -92,7 +97,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
 
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata,
+                          (FlashAttentionMetadata, TritonMLAMetadata))
 
         attn_metadata.advance_step(model_input, sampled_token_ids,
                                    self.block_size, num_seqs, num_queries)
@@ -193,6 +199,7 @@ def execute_model(
         # iteration invokes this function only once
         # (Look at multi-step-worker code)
         is_fallback = num_steps == 1
+        self.mtp = self.model.config.model_type == "deepseek_mtp"
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
@@ -269,6 +276,9 @@ def execute_model(
            hidden_states = previous_hidden_states
 
        outputs: List[SamplerOutput] = []
+        input_tokens = model_input.input_tokens
+        input_positions = model_input.input_positions
+        attn_metadata = model_input.attn_metadata
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
 
@@ -277,37 +287,64 @@ def execute_model(
 
            compute_logits_kwargs = {}
            # Run model
-            if hasattr(self.model.config, "num_nextn_predict_layers"):
+            spec_step_idx = kwargs.get("spec_step_idx", 0)
+            if self.model_config.requires_multi_step_decode:
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
-                model_execute_kwargs["spec_step_idx"] = spec_step_idx
-                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
-            with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                if spec_step_idx >= 0:
+                    model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                    compute_logits_kwargs["spec_step_idx"] = spec_step_idx
+
+                graph_batch_size = model_input.input_tokens.shape[0]
+                graph_idx = self.parallel_config.pipeline_parallel_size * spec_step_idx + model_input.virtual_engine
+                model_executable = self.graph_runners[graph_idx][graph_batch_size]
+            elif not use_cuda_graph:
+                # for single step prefill
+                with set_forward_context(attn_metadata, self.vllm_config):
+                    return model_executable.generate_proposals(
+                        input_ids=input_tokens,
+                        positions=input_positions,
+                        kv_caches=kv_caches,
+                        attn_metadata=attn_metadata,
+                        sampling_metadata=model_input.sampling_metadata,
+                        **model_execute_kwargs,
+                    )
+            # model_execute_kwargs["spec_step_idx"] = spec_step_idx
+            with set_forward_context(attn_metadata, self.vllm_config):
                hidden_states = model_executable(
-                    input_ids=model_input.input_tokens,
-                    positions=model_input.input_positions,
+                    input_ids=input_tokens,
+                    positions=input_positions,
+                    kv_caches=kv_caches,
+                    attn_metadata=attn_metadata,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                 device=self.device),
                    **model_execute_kwargs,
                )
 
            # Compute the logits.
-            logits = self.model.compute_logits(hidden_states,
-                                               model_input.sampling_metadata,
-                                               **compute_logits_kwargs)
+            logits = self.model.compute_logits(
+                hidden_states,  # do not sample for the previous tokens
+                model_input.sampling_metadata,
+                **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            output = self.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
+            # TODO: do sampling/compute logits for the last token only
+            if self.mtp:
+                # return last token only for each step for MTP
+                output = self.model.get_last_sample_output(
+                    output, attn_metadata)
+                input_tokens = self.model.get_next_layer_input(
+                    input_tokens, attn_metadata, output)
            outputs.append(output)
 
-            if model_input.attn_metadata.num_prefills == 0 \
+            if not self.mtp and model_input.attn_metadata.num_prefills == 0 \
                and self.indices_of_seq_with_bonus_tokens is not None:
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
@@ -327,7 +364,7 @@ def execute_model(
                    count += 1
 
            # Prepare inputs for the next step
-            if step != num_steps - 1:
+            if step != num_steps - 1 and not self.mtp:
                model_input = self._gpu_advance_step(model_input, outputs[-1])
 
        return outputs
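
A small arithmetic sketch (values made up) of how the new MTP decode path picks a CUDA-graph runner: each spec_step_idx owns a contiguous bank of pipeline_parallel_size runners, and the virtual engine offsets within that bank, mirroring the graph_idx expression in the diff:

# Hypothetical values for illustration only.
pipeline_parallel_size = 2
virtual_engine = 1
spec_step_idx = 3

# Same indexing formula as in the diff; the result selects
# self.graph_runners[graph_idx][graph_batch_size].
graph_idx = pipeline_parallel_size * spec_step_idx + virtual_engine
print(graph_idx)  # 7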

vllm/spec_decode/multi_step_worker.py

Lines changed: 19 additions & 8 deletions
@@ -56,6 +56,10 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)
 
+    @property
+    def has_mtp_runner(self) -> bool:
+        return getattr(self.model_runner, "mtp", False)
+
     @torch.inference_mode()
     def sampler_output(
         self,
@@ -74,10 +78,13 @@ def sampler_output(
         # Expand the batch for sequences with a bonus token.
         # Perform a forward pass on the expanded batch and filter the
         # response to retain only the original sequences' responses.
-        expanded_request, indices_of_seq_with_bonus_tokens =\
-            self._expand_execute_model_request(
-                execute_model_req, seq_ids_with_bonus_token_in_last_step)
-
+        if self.has_mtp_runner:
+            expanded_request, indices_of_seq_with_bonus_tokens =\
+                execute_model_req, []
+        else:
+            expanded_request, indices_of_seq_with_bonus_tokens =\
+                self._expand_execute_model_request(
+                    execute_model_req, seq_ids_with_bonus_token_in_last_step)
        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
@@ -109,10 +116,14 @@ def sampler_output(
            model_outputs.append(model_output)
 
        # move indices to device to avoid stream sync
-        indices_of_seq_with_bonus_tokens = torch.tensor(
-            indices_of_seq_with_bonus_tokens, device=self.device)
-        filtered_model_outputs = self._filter_model_output(
-            model_outputs, indices_of_seq_with_bonus_tokens)
+        if self.has_mtp_runner:
+            filtered_model_outputs = model_outputs
+        else:
+            indices_of_seq_with_bonus_tokens = torch.tensor(
+                indices_of_seq_with_bonus_tokens, device=self.device)
+            filtered_model_outputs = self._filter_model_output(
+                model_outputs, indices_of_seq_with_bonus_tokens)
+
        return filtered_model_outputs, True
 
    @staticmethod
