diff --git a/vllm/config.py b/vllm/config.py
index d1384c6375f3..fc55cd8b57dd 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1977,13 +1977,12 @@ def maybe_create_spec_config(
                 if num_speculative_tokens is None:
                     # Default to max value defined in draft model config.
                     num_speculative_tokens = n_predict
-                elif num_speculative_tokens > n_predict:
-                    # Verify provided value doesn't exceed the maximum
-                    # supported by the draft model.
+                elif num_speculative_tokens > n_predict and \
+                        num_speculative_tokens % n_predict != 0:
+                    # Ensure divisibility for MTP module reuse.
                     raise ValueError(
-                        "This speculative model supports a maximum of "
-                        f"num_speculative_tokens={n_predict}, but "
-                        f"{num_speculative_tokens=} was provided.")
+                        f"{num_speculative_tokens=} must be divisible by "
+                        f"{n_predict=}")
 
         speculative_draft_tensor_parallel_size = \
             SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index cac1b2b3b11c..e7fde76cd0ba 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -87,7 +87,7 @@ def forward(
                                                  hidden_states=hidden_states,
                                                  residual=None)
         hidden_states = residual + hidden_states
-        return self.shared_head(hidden_states)
+        return hidden_states
 
 
 class DeepSeekMultiTokenPredictor(nn.Module):
@@ -121,12 +121,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
-            spec_step_idx,
+            current_step_idx,
         )
 
     def compute_logits(
@@ -135,9 +136,12 @@ def compute_logits(
         sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        mtp_layer = self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       hidden_states, sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
         return logits
 
 
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index c54e6abe18d7..bc1b3e2319d0 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -50,12 +50,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """
 
     def __init__(self, model_runner: ModelRunnerBase):
-        if hasattr(
-                model_runner,
-                "return_hidden_states") and model_runner.return_hidden_states:
-            raise ValueError(
-                "return_hidden_states is not supported for TP1DraftModelRunner."
-            )
         super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
@@ -153,7 +147,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
             return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
+        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
             return False
 
         # TODO: Add support for LORA
@@ -307,6 +301,9 @@ def execute_model(
             )
             outputs.append(output)
 
+            if self.return_hidden_states and is_fallback:
+                output.hidden_states = hidden_states
+
             if model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index c28d413efe74..d8d54918fa98 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -96,12 +96,16 @@ def sampler_output(
             # TODO: Remove this branch once DraftModelRunner supports TP>1
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
+            if expanded_request.previous_hidden_states is not None:
+                self.worker.model_runner.return_hidden_states = True
             for _ in range(sample_len):
                 model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
                 model_output = model_output[0]
+                self._maybe_update_previous_hidden_states(
+                    model_output, expanded_request)
 
                 self._append_new_tokens(
                     model_output, expanded_request.seq_group_metadata_list,
@@ -115,6 +119,19 @@ def sampler_output(
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
 
+    @staticmethod
+    def _maybe_update_previous_hidden_states(
+            model_output: SamplerOutput,
+            expanded_request: ExecuteModelRequest) -> None:
+        """
+        Updates the previous hidden states in an expanded request
+        in-place with the hidden states from the model output.
+        """
+        if expanded_request.previous_hidden_states is not None:
+            expanded_request.previous_hidden_states = HiddenStates(
+                model_output.hidden_states,
+                expanded_request.seq_group_metadata_list)
+
     @staticmethod
     def _expand_execute_model_request(
         execute_model_req: ExecuteModelRequest,
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 871a3aee6306..8909a41bc99f 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -184,8 +184,7 @@ def create_worker(
             elif draft_model_config.hf_config.model_type == "medusa":
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
-                if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                        "deepseek_mtp":
+                if draft_tp == 1:
                     if current_platform.is_cuda_alike():
                         draft_worker_kwargs[
                             "model_runner_cls"] = TP1DraftModelRunner
@@ -203,7 +202,8 @@ def create_worker(
                 proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                num_spec_prefill_steps = \
+                    draft_model_config.hf_config.n_predict
 
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index a37a3168bbbc..bb2228165b52 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1685,11 +1685,22 @@ def execute_model(
         # TODO(andoorve): We can remove this once all
         # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
+        previous_hidden_states = kwargs.get("previous_hidden_states")
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[virtual_engine][
                 graph_batch_size]
+            if previous_hidden_states is not None:
+                previous_hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
         else:
             model_executable = self.model
 
@@ -1716,7 +1727,6 @@ def execute_model(
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_inner_state else {}
-        previous_hidden_states = kwargs.get("previous_hidden_states")
         model_kwargs = {}
         if previous_hidden_states is not None:
             model_kwargs["previous_hidden_states"] = previous_hidden_states