-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[Model][Speculative Decoding] support k > 1 for MTP #13805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
luccafong
wants to merge
1
commit into
vllm-project:main
Choose a base branch
from
luccafong:ds_mtp_multistep
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,11 @@ | |
| # vllm_flash_attn is not installed, try the ROCm FA metadata | ||
| from vllm.attention.backends.rocm_flash_attn import ( | ||
| ROCmFlashAttentionMetadata as FlashAttentionMetadata) | ||
| try: | ||
| from vllm.attention.backends.triton_mla import TritonMLAMetadata | ||
| except (ModuleNotFoundError, ImportError): | ||
| TritonMLAMetadata = FlashAttentionMetadata | ||
|
|
||
| except (ModuleNotFoundError, ImportError) as err: | ||
| raise RuntimeError( | ||
| "Draft model speculative decoding currently only supports " | ||
|
|
@@ -57,7 +62,7 @@ | |
| "return_hidden_states is not supported for TP1DraftModelRunner." | ||
| ) | ||
| super().__init__(model_runner) | ||
|
|
||
| self.mtp = False | ||
| self.indices_of_seq_with_bonus_tokens = None | ||
|
|
||
| def _update_sampling_metadata(self, sampling_metadata, num_seqs, | ||
|
|
@@ -92,7 +97,8 @@ | |
|
|
||
| # Update attn_metadata | ||
| attn_metadata = model_input.attn_metadata | ||
| assert isinstance(attn_metadata, FlashAttentionMetadata) | ||
| assert isinstance(attn_metadata, | ||
| (FlashAttentionMetadata, TritonMLAMetadata)) | ||
|
|
||
| attn_metadata.advance_step(model_input, sampled_token_ids, | ||
| self.block_size, num_seqs, num_queries) | ||
|
|
@@ -193,6 +199,7 @@ | |
| # iteration invokes this function only once | ||
| # (Look at multi-step-worker code) | ||
| is_fallback = num_steps == 1 | ||
| self.mtp = self.model.config.model_type == "deepseek_mtp" | ||
| if not is_fallback: | ||
| # Since we do not broadcast data inside execute_model anymore, | ||
| # we need to figure out the best way to support TP > 1 in this | ||
|
|
@@ -269,6 +276,9 @@ | |
| hidden_states = previous_hidden_states | ||
|
|
||
| outputs: List[SamplerOutput] = [] | ||
| input_tokens = model_input.input_tokens | ||
| input_positions = model_input.input_positions | ||
| attn_metadata = model_input.attn_metadata | ||
| for step in range(num_steps): | ||
| multi_modal_kwargs = model_input.multi_modal_kwargs or {} | ||
|
|
||
|
|
@@ -277,37 +287,64 @@ | |
|
|
||
| compute_logits_kwargs = {} | ||
| # Run model | ||
| if hasattr(self.model.config, "num_nextn_predict_layers"): | ||
| spec_step_idx = kwargs.get("spec_step_idx", 0) | ||
| if self.model_config.requires_multi_step_decode: | ||
| # for DeepSeek MTP only to use the corresponding layer for | ||
| # each step | ||
| spec_step_idx = kwargs.get("spec_step_idx", step) | ||
| model_execute_kwargs["spec_step_idx"] = spec_step_idx | ||
| compute_logits_kwargs["spec_step_idx"] = spec_step_idx | ||
| with set_forward_context(model_input.attn_metadata, | ||
| self.vllm_config): | ||
| if spec_step_idx >= 0: | ||
| model_execute_kwargs["spec_step_idx"] = spec_step_idx | ||
| compute_logits_kwargs["spec_step_idx"] = spec_step_idx | ||
|
|
||
| graph_batch_size = model_input.input_tokens.shape[0] | ||
| graph_idx = self.parallel_config.pipeline_parallel_size * spec_step_idx + model_input.virtual_engine | ||
| model_executable = self.graph_runners[graph_idx][graph_batch_size] | ||
| elif not use_cuda_graph: | ||
| # for single step prefill | ||
| with set_forward_context(attn_metadata, self.vllm_config): | ||
| return model_executable.generate_proposals( | ||
| input_ids=input_tokens, | ||
| positions=input_positions, | ||
| kv_caches=kv_caches, | ||
| attn_metadata=attn_metadata, | ||
| sampling_metadata=model_input.sampling_metadata, | ||
| **model_execute_kwargs, | ||
| ) | ||
| # model_execute_kwargs["spec_step_idx"] = spec_step_idx | ||
| with set_forward_context(attn_metadata, self.vllm_config): | ||
| hidden_states = model_executable( | ||
| input_ids=model_input.input_tokens, | ||
| positions=model_input.input_positions, | ||
| input_ids=input_tokens, | ||
| positions=input_positions, | ||
| kv_caches=kv_caches, | ||
| attn_metadata=attn_metadata, | ||
| intermediate_tensors=intermediate_tensors, | ||
| **MultiModalKwargs.as_kwargs(multi_modal_kwargs, | ||
| device=self.device), | ||
| **model_execute_kwargs, | ||
| ) | ||
|
|
||
| # Compute the logits. | ||
| logits = self.model.compute_logits(hidden_states, | ||
| model_input.sampling_metadata, | ||
| **compute_logits_kwargs) | ||
| logits = self.model.compute_logits( | ||
| hidden_states, # do not sample for the previous tokens | ||
| model_input.sampling_metadata, | ||
| **compute_logits_kwargs) | ||
| if not self.is_driver_worker: | ||
| return [] | ||
| # Sample the next token. | ||
| output = self.model.sample( | ||
| logits=logits, | ||
| sampling_metadata=model_input.sampling_metadata, | ||
| ) | ||
| # TODO: do sampling/compute logits for the last token only | ||
| if self.mtp: | ||
| # return last token only for each step for MTP | ||
| output = self.model.get_last_sample_output( | ||
| output, attn_metadata) | ||
| input_tokens = self.model.get_next_layer_input( | ||
| input_tokens, attn_metadata, output) | ||
| outputs.append(output) | ||
|
|
||
| if model_input.attn_metadata.num_prefills == 0 \ | ||
| if not self.mtp and model_input.attn_metadata.num_prefills == 0 \ | ||
| and self.indices_of_seq_with_bonus_tokens is not None: | ||
| assert output.sampled_token_ids is not None | ||
| # output.sampled_token_ids should be of shape (num_seqs, 1) | ||
|
|
@@ -327,7 +364,7 @@ | |
| count += 1 | ||
|
|
||
| # Prepare inputs for the next step | ||
| if step != num_steps - 1: | ||
| if step != num_steps - 1 and not self.mtp: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the multi-step logic omitted here, and |
||
| model_input = self._gpu_advance_step(model_input, outputs[-1]) | ||
|
|
||
| return outputs | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this block skipped?