diff --git a/vllm/config.py b/vllm/config.py
index 8e1ce87438af..bed62a4fec2c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1010,6 +1010,18 @@ def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
         return is_encoder_decoder(self.hf_config)
 
+    @property
+    def requires_multi_step_decode(self) -> bool:
+        return getattr(self.hf_config, "model_type", "") == "deepseek_mtp" and \
+            getattr(self.hf_config, "num_nextn_predict_layers", 0) > 1
+
+    @property
+    def num_decode_modules(self) -> int:
+        if getattr(self.hf_config, "model_type", "") == "deepseek_mtp":
+            return getattr(self.hf_config, "num_nextn_predict_layers", 0)
+        else:
+            return 1
+
     @property
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
@@ -3468,7 +3480,8 @@ def _set_cudagraph_sizes(self):
                 # which then becomes the max_batchsize_to_capture
                 larger_sizes = [
                     x for x in possible_sizes
-                    if x >= self.scheduler_config.max_num_seqs
+                    if x >= self.scheduler_config.max_num_seqs *
+                    self.model_config.num_decode_modules
                 ]
                 if larger_sizes:
                     max_batchsize_to_capture = larger_sizes[0]
@@ -3481,6 +3494,7 @@ def _set_cudagraph_sizes(self):
                     size for size in possible_sizes
                     if size <= max_batchsize_to_capture
                 ]
+                # print(f"{batch_size_capture_list=}")
             else:
                 batch_size_capture_list = []
                 if self.model_config is not None and \
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 8ceef855e020..07d9be6ac3f2 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -185,6 +185,8 @@ def _process_seq_outputs(self, seq: Sequence,
         is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
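+        # The MTP proposer (top1_proposer.get_spec_proposals) reads
+        # seq.data._new_appended_tokens to infer how many tokens were appended
+        # in this step, so the previous step's record is cleared first.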
+        # TODO: add an attribute here for reset, can be set at output processor
+        seq.data.reset_new_appended_tokens()
         for output_token_id, output_logprob in zip(output_token_ids,
                                                    output_logprobs):
             seq.append_token_id(
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index cac1b2b3b11c..498e9849a610 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -175,6 +175,47 @@ def compute_logits(
         return self.model.compute_logits(hidden_states, sampling_metadata,
                                          spec_step_idx)
 
+    def generate_proposals(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        previous_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> List[SamplerOutput]:
+        hidden_states = previous_hidden_states
+        cur_input_ids = input_ids
+        outputs = []
+        for i in range(self.model.num_mtp_layers):
+            hidden_states = self.forward(cur_input_ids,
+                                         positions,
+                                         kv_caches,
+                                         attn_metadata,
+                                         hidden_states,
+                                         spec_step_idx=i)
+            logits = self.compute_logits(hidden_states=hidden_states,
+                                         sampling_metadata=sampling_metadata,
+                                         spec_step_idx=i)
+            output = self.sample(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+            outputs.append(output)
+            cur_input_ids = self.get_next_layer_input(input_ids, attn_metadata,
+                                                      output)
+        return outputs
+
+    def get_next_layer_input(
+            self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata,
+            outputs: SamplerOutput) -> torch.Tensor:
+        assert outputs.sampled_token_ids is not None
+        assert attn_metadata.query_start_loc is not None
+        input_ids = input_ids.roll(shifts=-1, dims=0)
+        query_end_loc = attn_metadata.query_start_loc[1:] - 1
+        input_ids[query_end_loc] = outputs.sampled_token_ids[:, 0]
+        return input_ids
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -183,6 +224,18 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def get_last_sample_output(
+        self,
+        output: SamplerOutput,
+        attn_metadata: AttentionMetadata,
+    ) -> SamplerOutput:
+        query_end_loc = attn_metadata.query_start_loc[1:] - 1
+        output.sampled_token_ids = output.sampled_token_ids[query_end_loc]
+        if output.sampled_token_probs is not None:
+            output.sampled_token_probs = output.sampled_token_probs[
+                query_end_loc]
+        return output
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
diff --git a/vllm/sequence.py b/vllm/sequence.py
index c0425ba33c9a..3045632a51cb 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -365,6 +365,9 @@ def get_delta_and_reset(self) -> SequenceDataDelta:
         self._new_appended_tokens = []
         return delta
 
+    def reset_new_appended_tokens(self) -> None:
+        self._new_appended_tokens = []
+
     def apply_delta(self, delta: SequenceDataDelta):
         self._num_computed_tokens = delta.new_num_computed_tokens
         self._cumulative_logprob = delta.new_cumulative_logprob
@@ -1212,12 +1215,13 @@ class HiddenStates(msgspec.Struct, array_like=True,
     # last proposed token is accepted (i.e., in case of bonus tokens). For the
     # case of no bonus tokens, these are ignored.
     second_last_token_hidden_states: Optional[torch.Tensor] = None
-
+    # for varseq
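+    # (each row of hidden_states maps to the sequence index recorded here,
+    # so a sequence may contribute a variable number of hidden-state rows)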
+    hidden_states_seq_indices: Optional[torch.Tensor] = None
     _seq_ids: List[int] = msgspec.field(default_factory=list)
 
     def __post_init__(self):
         if self.seq_group_metadata_list is not None:
-            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            # TODO: add assertion for the group metadata list with var seqs
             self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
 
     @property
@@ -1231,8 +1235,20 @@ def update(self,
         """Update hidden states from target model invocation. Only used
         for decode steps"""
         assert len(seq_group_metadata_list) == len(hidden_states)
-        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        last_seq_indice = len(self._seq_ids)
+        new_seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        self._seq_ids.extend(new_seq_ids)
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+        if self.hidden_states_seq_indices is not None:
+            updated_indices = list(range(last_seq_indice, len(self._seq_ids)))
+            # assume new updated are hidden states from
+            # prefill which is always length of 1
+            new_seq_indices = torch.tensor(
+                updated_indices, device=self.hidden_states_seq_indices.device)
+            self.hidden_states_seq_indices = torch.concat([
+                self.hidden_states_seq_indices,
+                new_seq_indices,
+            ])
 
         if self.second_last_token_hidden_states is not None:
             # Adding dummy hidden_states to this to maintain same shape
@@ -1255,10 +1271,17 @@ def prune(self,
         if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
             index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
-            self.hidden_states = self.hidden_states[index]
-            if self.second_last_token_hidden_states is not None:
-                self.second_last_token_hidden_states = self\
-                    .second_last_token_hidden_states[index]
+            if self.hidden_states_seq_indices is not None:
+                target_indices_tensor = torch.tensor(
+                    index, device=self.hidden_states_seq_indices.device)
+                index = (self.hidden_states_seq_indices[..., None] ==
+                         target_indices_tensor).any(dim=-1)
+                self.hidden_states = self.hidden_states[index]
+            else:
+                self.hidden_states = self.hidden_states[index]
+                if self.second_last_token_hidden_states is not None:
+                    self.second_last_token_hidden_states = self\
+                        .second_last_token_hidden_states[index]
             self._seq_ids = seq_ids
 
     def expand_with_bonus_tokens(
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index c54e6abe18d7..c01e8a22859f 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -14,6 +14,11 @@
         # vllm_flash_attn is not installed, try the ROCm FA metadata
         from vllm.attention.backends.rocm_flash_attn import (
             ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+    try:
+        from vllm.attention.backends.triton_mla import TritonMLAMetadata
+    except (ModuleNotFoundError, ImportError):
+        TritonMLAMetadata = FlashAttentionMetadata
+
 except (ModuleNotFoundError, ImportError) as err:
     raise RuntimeError(
         "Draft model speculative decoding currently only supports "
@@ -57,7 +62,7 @@ def __init__(self, model_runner: ModelRunnerBase):
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
         super().__init__(model_runner)
-
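+        # Set to True in execute_model when the draft model is a DeepSeek MTP
+        # model; gates the MTP-specific branches below.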
+        self.mtp = False
         self.indices_of_seq_with_bonus_tokens = None
 
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
@@ -92,7 +97,8 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
 
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata,
+                          (FlashAttentionMetadata, TritonMLAMetadata))
 
         attn_metadata.advance_step(model_input, sampled_token_ids,
                                    self.block_size, num_seqs, num_queries)
@@ -193,6 +199,7 @@ def execute_model(
         #   iteration invokes this function only once
         #   (Look at multi-step-worker code)
         is_fallback = num_steps == 1
+        self.mtp = self.model.config.model_type == "deepseek_mtp"
         if not is_fallback:
             # Since we do not broadcast data inside execute_model anymore,
             # we need to figure out the best way to support TP > 1 in this
@@ -269,6 +276,9 @@ def execute_model(
             hidden_states = previous_hidden_states
 
         outputs: List[SamplerOutput] = []
+        input_tokens = model_input.input_tokens
+        input_positions = model_input.input_positions
+        attn_metadata = model_input.attn_metadata
         for step in range(num_steps):
             multi_modal_kwargs = model_input.multi_modal_kwargs or {}
 
@@ -277,17 +287,36 @@ def execute_model(
             compute_logits_kwargs = {}
             # Run model
-            if hasattr(self.model.config, "num_nextn_predict_layers"):
+            spec_step_idx = kwargs.get("spec_step_idx", 0)
+            if self.model_config.requires_multi_step_decode:
                 # for DeepSeek MTP only to use the corresponding layer for
                 # each step
                 spec_step_idx = kwargs.get("spec_step_idx", step)
-                model_execute_kwargs["spec_step_idx"] = spec_step_idx
-                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
-            with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                if spec_step_idx >= 0:
+                    model_execute_kwargs["spec_step_idx"] = spec_step_idx
+                    compute_logits_kwargs["spec_step_idx"] = spec_step_idx
+
+                graph_batch_size = model_input.input_tokens.shape[0]
+                graph_idx = self.parallel_config.pipeline_parallel_size * spec_step_idx + model_input.virtual_engine
+                model_executable = self.graph_runners[graph_idx][graph_batch_size]
+            elif not use_cuda_graph:
+                # for single step prefill
+                with set_forward_context(attn_metadata, self.vllm_config):
+                    return model_executable.generate_proposals(
+                        input_ids=input_tokens,
+                        positions=input_positions,
+                        kv_caches=kv_caches,
+                        attn_metadata=attn_metadata,
+                        sampling_metadata=model_input.sampling_metadata,
+                        **model_execute_kwargs,
+                    )
+            # model_execute_kwargs["spec_step_idx"] = spec_step_idx
+            with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = model_executable(
-                    input_ids=model_input.input_tokens,
-                    positions=model_input.input_positions,
+                    input_ids=input_tokens,
+                    positions=input_positions,
+                    kv_caches=kv_caches,
+                    attn_metadata=attn_metadata,
                     intermediate_tensors=intermediate_tensors,
                     **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                  device=self.device),
@@ -295,9 +324,10 @@ def execute_model(
                 )
 
             # Compute the logits.
-            logits = self.model.compute_logits(hidden_states,
-                                               model_input.sampling_metadata,
-                                               **compute_logits_kwargs)
+            logits = self.model.compute_logits(
+                hidden_states,  # do not sample for the previous tokens
+                model_input.sampling_metadata,
+                **compute_logits_kwargs)
             if not self.is_driver_worker:
                 return []
             # Sample the next token.
@@ -305,9 +335,16 @@ def execute_model(
                 logits=logits,
                 sampling_metadata=model_input.sampling_metadata,
             )
+            # TODO: do sampling/compute logits for the last token only
+            if self.mtp:
+                # return last token only for each step for MTP
+                output = self.model.get_last_sample_output(
+                    output, attn_metadata)
+                input_tokens = self.model.get_next_layer_input(
+                    input_tokens, attn_metadata, output)
             outputs.append(output)
 
-            if model_input.attn_metadata.num_prefills == 0 \
+            if not self.mtp and model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
                 assert output.sampled_token_ids is not None
                 # output.sampled_token_ids should be of shape (num_seqs, 1)
@@ -327,7 +364,7 @@ def execute_model(
                         count += 1
 
             # Prepare inputs for the next step
-            if step != num_steps - 1:
+            if step != num_steps - 1 and not self.mtp:
                 model_input = self._gpu_advance_step(model_input, outputs[-1])
 
         return outputs
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index c28d413efe74..5dd0dd09ef44 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -56,6 +56,10 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)
 
+    @property
+    def has_mtp_runner(self) -> bool:
+        return getattr(self.model_runner, "mtp", False)
+
     @torch.inference_mode()
     def sampler_output(
         self,
@@ -74,10 +78,13 @@ def sampler_output(
         # Expand the batch for sequences with a bonus token.
         # Perform a forward pass on the expanded batch and filter the
         # response to retain only the original sequences' responses.
-        expanded_request, indices_of_seq_with_bonus_tokens =\
-            self._expand_execute_model_request(
-                execute_model_req, seq_ids_with_bonus_token_in_last_step)
-
+        if self.has_mtp_runner:
+            expanded_request, indices_of_seq_with_bonus_tokens =\
+                execute_model_req, []
+        else:
+            expanded_request, indices_of_seq_with_bonus_tokens =\
+                self._expand_execute_model_request(
+                    execute_model_req, seq_ids_with_bonus_token_in_last_step)
         # Run model sample_len times.
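+        # With the MTP runner the request above is passed through unexpanded,
+        # so the draft outputs are later returned without bonus-token filtering.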
         model_outputs: List[SamplerOutput] = []
         if current_platform.is_cuda_alike() and isinstance(
@@ -109,10 +116,14 @@ def sampler_output(
                 model_outputs.append(model_output)
 
         # move indices to device to avoid stream sync
-        indices_of_seq_with_bonus_tokens = torch.tensor(
-            indices_of_seq_with_bonus_tokens, device=self.device)
-        filtered_model_outputs = self._filter_model_output(
-            model_outputs, indices_of_seq_with_bonus_tokens)
+        if self.has_mtp_runner:
+            filtered_model_outputs = model_outputs
+        else:
+            indices_of_seq_with_bonus_tokens = torch.tensor(
+                indices_of_seq_with_bonus_tokens, device=self.device)
+            filtered_model_outputs = self._filter_model_output(
+                model_outputs, indices_of_seq_with_bonus_tokens)
+
         return filtered_model_outputs, True
 
     @staticmethod
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 8af71842224b..518363265e89 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -161,7 +161,7 @@ def create_worker(
         allow_zero_draft_token_step = True
         enable_lm_head_weight_load = False
-        num_spec_prefill_steps = 1
+        next_n_prediction_steps = -1
         ngram_prompt_lookup_max = (
             draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
         ngram_prompt_lookup_min = (
@@ -203,7 +203,8 @@ def create_worker(
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             if draft_model_config.hf_config.model_type == "deepseek_mtp":
-                num_spec_prefill_steps = num_speculative_tokens
+                next_n_prediction_steps = num_speculative_tokens
+                proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
@@ -257,7 +258,8 @@ def create_worker(
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step,
             enable_lm_head_weight_load=enable_lm_head_weight_load,
-            num_spec_prefill_steps=num_spec_prefill_steps)
+            next_n_prediction_steps=next_n_prediction_steps,
+        )
 
     def __init__(
         self,
@@ -271,7 +273,7 @@ def __init__(
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
         enable_lm_head_weight_load: Optional[bool] = False,
-        num_spec_prefill_steps: int = 1,
+        next_n_prediction_steps: int = -1,
     ):
         """
         Create a SpecDecodeWorker.
@@ -305,6 +307,7 @@ def __init__(
             enable_lm_head_weight_load: whether to load lm_head weight for
                 draft models like eagle.
             num_spec_prefill_steps: number of speculative prefill steps to run
+            next_n_prediction_steps: number of speculative prefill steps to run
                 before the speculative decoding starts. This is only used
                 when the draft model is a deepseek_mtp model that requires
                 prefill kv cache separately for each MTP layer.
@@ -341,7 +344,7 @@ def __init__(
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
         self._disable_log_stats = disable_log_stats
-        self._num_spec_prefill_steps = num_spec_prefill_steps
+        self._next_n_prediction_steps = next_n_prediction_steps
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -541,8 +544,9 @@ def execute_model(
         if no_spec:
             return self._run_no_spec(execute_model_req,
                                      skip_proposer=disable_all_speculation)
-        return self._run_speculative_decoding_step(execute_model_req,
-                                                   num_lookahead_slots)
+        results = self._run_speculative_decoding_step(execute_model_req,
+                                                      num_lookahead_slots)
+        return results
 
     @torch.inference_mode()
     def start_worker_execution_loop(self) -> None:
@@ -671,7 +675,6 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
         not called, meaning that the kv-cache in proposer for requests is not
         updated, so they cannot enable spec decode in the rest decoding.
         """
-
         sampler_output = self.scorer_worker.execute_model(execute_model_req)
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]
@@ -692,7 +695,8 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
             if self.previous_hidden_states is None and len(
                     seq_group_meta_with_hidden):
                 self.previous_hidden_states = HiddenStates(
-                    hidden_states, seq_group_meta_with_hidden)
+                    hidden_states, seq_group_meta_with_hidden
+                )  # hidden states for T, (T+1 token)
             elif self.previous_hidden_states and len(
                     seq_group_meta_with_hidden):
                 self.previous_hidden_states.update(hidden_states,
@@ -702,13 +706,11 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
            # We prepare the prefill hidden states here so that there no
            # additional complexity in worker for spec_decode vs non_spec_decode
            # flow and execute_model doesn't need additional modifications.
+            execute_model_req.spec_step_idx = -1
             execute_model_req.previous_hidden_states = \
                 prepare_prefill_hidden_states(
                     sampler_output.prefill_hidden_states)
-            for i in range(self._num_spec_prefill_steps):
-                execute_model_req.spec_step_idx = i
-                self.proposer_worker.execute_model(execute_model_req)
-
+            self.proposer_worker.execute_model(execute_model_req)
         sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
             execute_model_req=execute_model_req,
             sampler_output=sampler_output) if self._disable_logprobs else
@@ -914,22 +916,39 @@ def _verify_tokens(
             # Contract hidden states based on accepted tokens
             hs_size = hidden_states.shape[-1]
             accepted_index = accepted_token_ids + 1  # Convert -1 to 0
-            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
-            # Drop non-terminal prefill chunks hidden states.
-            hidden_states = hidden_states[accepted_index !=
-                                          VLLM_INVALID_TOKEN_ID]
-            accepted_index = accepted_index[accepted_index !=
-                                            VLLM_INVALID_TOKEN_ID]
-            assert len(accepted_index) == hidden_states.shape[0] == len(
-                terminal_metadata)
-            index = accepted_index[:, None, None].expand(-1, 1,
-                                                         hs_size)  # b x 1 x d
-            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
-            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
-            # Store hidden states from target model for subsequent decode step
-            self.previous_hidden_states = HiddenStates(
-                hidden_states, terminal_metadata,
-                second_last_token_hidden_states)
+            accepted_index = accepted_index.count_nonzero(dim=1)
+            if self._next_n_prediction_steps > 0:
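+                # For MTP, keep the hidden state of every accepted token (a
+                # variable number per sequence) and record the owning sequence
+                # of each row in hidden_states_seq_indices.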
+                hidden_states = hidden_states.reshape(-1, hs_size)[
+                    accepted_token_ids.reshape(-1) != VLLM_INVALID_TOKEN_ID]
+                seq_indices = torch.repeat_interleave(
+                    torch.arange(0, accepted_index.shape[0]).to(
+                        hidden_states.device),
+                    accepted_index)  # seq indices for each hidden state
+                self.previous_hidden_states = HiddenStates(
+                    hidden_states,
+                    terminal_metadata,
+                    hidden_states_seq_indices=seq_indices,
+                )
+            else:
+                # Drop non-terminal prefill chunks hidden states.
+                hidden_states = hidden_states[accepted_index !=
+                                              VLLM_INVALID_TOKEN_ID]
+                accepted_index.add_(-1)
+                accepted_index = accepted_index[accepted_index !=
+                                                VLLM_INVALID_TOKEN_ID]
+                assert len(accepted_index) == hidden_states.shape[0] == len(
+                    terminal_metadata)
+                index = accepted_index[:, None,
+                                       None].expand(-1, 1,
+                                                    hs_size)  # b x 1 x d
+                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
+                hidden_states = hidden_states.gather(1,
+                                                     index).squeeze(1)  # b x d
+                # Store hidden states from target model for
+                # subsequent decode step
+                self.previous_hidden_states = HiddenStates(
+                    hidden_states, terminal_metadata,
+                    second_last_token_hidden_states)
 
         return accepted_token_ids, logprobs
 
     def _create_output_sampler_list(
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index b538923c03e7..435f22d472ef 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -5,7 +5,8 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -52,15 +53,39 @@ def get_spec_proposals(
             speculation.
         """
         proposal_len = execute_model_req.num_lookahead_slots
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-
+        if hasattr(self._worker, "model_runner") and getattr(
+                self._worker.model_runner, "mtp", False):
+            seq_group_metadata_list_for_proposal = []
+            for metadata in execute_model_req.seq_group_metadata_list:
+                mtp_seq_data = {}
+                for key, seq_data in metadata.seq_data.items():
+                    mtp_seq_data[key] = SequenceData.from_seqs(
+                        seq_data.prompt_token_ids,
+                        output_token_ids=seq_data.output_token_ids,
+                    )
+                    mtp_seq_data[key].update_num_computed_tokens(
+                        len(seq_data.prompt_token_ids) +
+                        len(seq_data.output_token_ids) -
+                        len(seq_data._new_appended_tokens))
+                new_metadata = SequenceGroupMetadata(
+                    request_id=metadata.request_id,
+                    is_prompt=False,
+                    seq_data=mtp_seq_data,
+                    sampling_params=metadata.sampling_params,
+                    block_tables=metadata.block_tables,
+                    lora_request=metadata.lora_request,
+                )
+                seq_group_metadata_list_for_proposal.append(new_metadata)
+        else:
+            seq_group_metadata_list_for_proposal =\
+                execute_model_req.seq_group_metadata_list
         # Split speculative- and non-speculative- sequences.
         (
             proposal_lens,
             nonzero_proposal_len_seqs,
             nonzero_proposal_len_indices,
-        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
-
+        ) = self._split_by_proposal_len(seq_group_metadata_list_for_proposal,
+                                        proposal_len)
         if nonzero_proposal_len_seqs:
             # Speculate tokens using the draft worker for the speculative
             # sequences.
@@ -98,7 +123,7 @@ def get_spec_proposals(
         # Combine speculative- and non-speculative sequences into the same
         # representation.
         proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
-            batch_size=len(seq_group_metadata_list),
+            batch_size=len(seq_group_metadata_list_for_proposal),
             proposal_len=proposal_len,
             maybe_sampler_output=maybe_sampler_output,
             proposal_lens=proposal_lens,
@@ -246,7 +271,6 @@ def _merge_outputs(
         sampler_output = maybe_sampler_output
         proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
             sampler_output, sampler_transposed)
-
         # Now, reformat the output GPU tensors such that each sequence has
         # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 86dcde234f86..14de9516978c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -772,7 +772,8 @@ def _use_captured_graph(self,
     def _get_cuda_graph_pad_size(self,
                                  num_seqs: int,
                                  max_decode_seq_len: int,
-                                 max_encoder_seq_len: int = 0) -> int:
+                                 max_encoder_seq_len: int = 0,
+                                 total_seq_len: int = 0) -> int:
         """
         Determine the number of padding sequences required for running in
         CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
@@ -792,6 +793,9 @@ def _get_cuda_graph_pad_size(self,
             max_encoder_seq_len (int, optional): Greatest of all the encode
                 sequence lengths. Defaults to 0. Used only in checking the
                 viability of using CUDA graphs.
+            total_seq_len (int, optional): Total number of tokens. If
+                specified, it is used to determine the number of padding
+                sequences.
         Returns:
             int: Returns the determined number of padding sequences. If
             CUDA graphs is not viable, returns -1.
@@ -810,12 +814,22 @@ def _get_cuda_graph_pad_size(self,
         if not self._use_captured_graph(batch_size, decode_only,
                                         max_decode_seq_len,
                                         max_encoder_seq_len):
+            # print(f"do not use captured graph for batch_size {batch_size}")
             return -1
-
-        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
-            batch_size)
-        assert graph_batch_size >= batch_size
-        return graph_batch_size - batch_size
+        if total_seq_len > batch_size:
+            # Multi-step decode case: pad the total token count (rather than
+            # the sequence count) to a captured CUDA graph size so that the
+            # decode steps can run in CUDA graph mode.
+            graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
+                total_seq_len
+            )
+            assert graph_batch_size >= total_seq_len
+            return graph_batch_size - total_seq_len
+        else:
+            graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
+                batch_size)
+            assert graph_batch_size >= batch_size
+            return graph_batch_size - batch_size
 
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
@@ -881,7 +895,14 @@ def build(self) -> ModelInputForGPU:
         cuda_graph_pad_size = self._get_cuda_graph_pad_size(
             num_seqs=len(seq_lens),
             max_decode_seq_len=max_decode_seq_len,
-            max_encoder_seq_len=max_encoder_seq_len)
+            max_encoder_seq_len=max_encoder_seq_len,
+            total_seq_len=len(input_tokens) if self.runner.model_config.requires_multi_step_decode else 0,
+        )
+        if self.runner.model_config.requires_multi_step_decode and cuda_graph_pad_size >= 0:
+            batch_cuda_graph_pad_size = cuda_graph_pad_size + len(input_tokens) - len(seq_lens)
+        else:
+            batch_cuda_graph_pad_size = cuda_graph_pad_size
+
         batch_size = len(input_tokens)
         if cuda_graph_pad_size != -1:
@@ -918,9 +939,12 @@ def build(self) -> ModelInputForGPU:
                                                self.runner.device,
                                                self.runner.pin_memory)
 
         # Sequence and query lengths.
-        if cuda_graph_pad_size:
-            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
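+        # seq_lens has one entry per sequence, but for multi-step decode the
+        # CUDA graph batch is padded to a token count, so it is extended by
+        # batch_cuda_graph_pad_size (cuda_graph_pad_size plus the difference
+        # between token and sequence counts) rather than cuda_graph_pad_size.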
+        # print(f"before {seq_lens=}, {batch_cuda_graph_pad_size=}, {cuda_graph_pad_size=}")
+        if batch_cuda_graph_pad_size:
+            seq_lens.extend(itertools.repeat(1, batch_cuda_graph_pad_size))
+
+        # print(f"after {len(seq_lens)=}")
 
         # Attention metadata.
         attn_metadata = self.attn_metadata_builder.build(
             seq_lens, query_lens, cuda_graph_pad_size, batch_size)
@@ -1030,7 +1054,7 @@ def __init__(
             self.vllm_config.compilation_config.max_capture_size
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
-            {} for _ in range(self.parallel_config.pipeline_parallel_size)
+            {} for _ in range(self.parallel_config.pipeline_parallel_size * self.model_config.num_decode_modules)
         ]
         self.graph_memory_pool: Optional[Tuple[
             int, int]] = None  # Set during graph capture.
@@ -1478,8 +1502,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
-            for virtual_engine in range(
-                    self.parallel_config.pipeline_parallel_size):
+            requires_spec_decode_idx = self.model_config.requires_multi_step_decode
+            graph_indices = range(
+                self.parallel_config.pipeline_parallel_size *
+                self.model_config.num_decode_modules)
+            # [module 0: v0, v1, module 1: v0, v1, ...]
+            # print(f"{graph_indices=}")
+            for graph_idx in graph_indices:
+                virtual_engine = graph_idx % self.parallel_config.pipeline_parallel_size
+                spec_idx = int(graph_idx / self.parallel_config.pipeline_parallel_size)
                 # Only rank 0 should print progress bar during capture
                 cudagraph_capture_sizes = (tqdm(
                     self.vllm_config.compilation_config.
@@ -1530,8 +1561,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                         "memory_pool":
                         self.graph_memory_pool,
                         "stream":
-                        graph_capture_context.stream
+                        graph_capture_context.stream,
                     }
+                    if requires_spec_decode_idx:
+                        capture_inputs["spec_step_idx"] = spec_idx
                     if previous_hidden_states is not None:
                         capture_inputs[
                             "previous_hidden_states"] = previous_hidden_states[:
@@ -1554,7 +1587,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                             virtual_engine):
                         graph_runner.capture(**capture_inputs)
                     self.graph_memory_pool = graph_runner.graph.pool()
-                    self.graph_runners[virtual_engine][batch_size] = (
+                    self.graph_runners[graph_idx][batch_size] = (
                         graph_runner)
 
         end_time = time.perf_counter()