diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index a6eb628f9198f..7744b2640fe94 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -7,6 +7,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )
     worker = create_worker(
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index ce5b347832c30..68802f0b8468d 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -14,6 +14,7 @@
                            SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker import Worker
 
 T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
                   num_gpu_blocks: int,
                   seed: int,
                   is_driver_worker: bool = True,
-                  enforce_eager: bool = True) -> T:
+                  enforce_eager: bool = True,
+                  model_runner_cls: Optional[ModelRunner] = None) -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
         rank=0,
         distributed_init_method=distributed_init_method,
         is_driver_worker=is_driver_worker,
+        model_runner_cls=model_runner_cls,
     )
 
     worker.init_device()
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 0925d15461fd5..19a75c5a353a1 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -903,6 +903,8 @@ class ExecuteModelRequest:
     running_queue_size: int = 0
     # Optional hidden states from prior step.
     previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -916,4 +918,5 @@ def clone(
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
         )
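The new num_steps field defaults to 1, so existing single-step callers are unchanged. As a rough illustration (not part of this diff; execute_model_req stands in for an existing request), a multi-step-capable caller clones the request and raises the step count:

# Hypothetical sketch only: ask a multi-step-capable runner for four
# consecutive forward passes by cloning the request with num_steps set.
draft_req = execute_model_req.clone(
    execute_model_req.seq_group_metadata_list)
draft_req.num_steps = 4  # defaults to 1, i.e. ordinary single-step execution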
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
new file mode 100644
index 0000000000000..f30d29376121a
--- /dev/null
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -0,0 +1,170 @@
+from typing import List, Optional
+
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
+                                      ModelRunner)
+
+logger = init_logger(__name__)
+
+
+class TP1DraftModelRunner(ModelRunner):
+    """Specialized model runner for the speculative decoding draft model.
+
+    Since the draft model always executes k forward passes consecutively to
+    generate k speculative tokens in a single speculative decoding step,
+    we can avoid most CPU-GPU synchronization and data transfer overhead
+    by keeping the model input and output tensors on the GPU the whole time.
+
+    This runner is still under development, so there is no performance gain
+    yet. Currently we use a temporary solution that caches the
+    seq_group_metadata_list for multi-step execution, so that we can reuse
+    the existing prepare_model_input and stay compatible with the current
+    execution flow; we plan to remove this cache and avoid calling
+    prepare_model_input from execute_model entirely.
+
+    The detailed development plan is:
+    1. Use "update_model_input" to update the existing model_input without
+       creating a new one.
+    2. Improve the performance of "update_model_input" with a GPU kernel.
+    3. Support TP > 1 (this requires some design work because we do not
+       expect any broadcasting inside execute_model).
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
+    ):
+        if return_hidden_states:
+            raise ValueError(
+                "return_hidden_states is not supported for TP1DraftModelRunner."
+            )
+
+        super().__init__(
+            model_config=model_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            cache_config=cache_config,
+            load_config=load_config,
+            lora_config=lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            vision_language_config=vision_language_config,
+            return_hidden_states=return_hidden_states,
+        )
+
+        # TODO: Remove this cache when we are able to update model_input
+        # directly in advance_step.
+        self.cached_seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """A temporary solution that caches the seq_group_metadata_list
+        for multi-step execution.
+        TODO: In-place update model_input and remove this function.
+        """
+        self.cached_seq_group_metadata_list = seq_group_metadata_list
+        return super().prepare_model_input(seq_group_metadata_list)
+
+    def update_model_input(
+            self, model_input: ModelInputForGPUWithSamplingMetadata,
+            last_output: SamplerOutput
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """Prepare the model inputs for the next step.
+        TODO: In-place update model_input instead of calling
+        prepare_model_input.
+        """
+
+        # Append the output token to the sequence data.
+        assert self.cached_seq_group_metadata_list is not None
+        for seq_group_metadata, sequence_group_outputs in zip(
+                self.cached_seq_group_metadata_list, last_output.outputs):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.update_num_computed_tokens(1)
+
+        return self.prepare_model_input(self.cached_seq_group_metadata_list)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        # Since we do not broadcast data inside execute_model anymore,
+        # we need to figure out the best way to support TP > 1 in this
+        # case, because we will at least need to broadcast the sampled
+        # tokens to all workers.
+        if not self.is_driver_worker:
+            raise ValueError("TP1DraftModelRunner only supports TP=1.")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        outputs: List[SamplerOutput] = []
+        for step in range(num_steps):
+            # Currently CUDA graphs are only supported for the decode phase.
+            assert model_input.attn_metadata is not None
+            prefill_meta = model_input.attn_metadata.prefill_metadata
+            decode_meta = model_input.attn_metadata.decode_metadata
+            if prefill_meta is None and decode_meta.use_cuda_graph:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = self.graph_runners[graph_batch_size]
+            else:
+                model_executable = self.model
+
+            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+            hidden_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                **multi_modal_kwargs,
+            )
+
+            # Compute the logits.
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+
+            # Sample the next token.
+            outputs.append(
+                self.model.sample(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                ))
+
+            # Prepare the inputs for the next step.
+            if step != num_steps - 1:
+                model_input = self.update_model_input(model_input, outputs[-1])
+
+        return outputs
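A rough sketch of the intended call pattern for the new runner follows (illustrative only; draft_runner, seq_group_metadata_list, kv_caches, and k are placeholders rather than part of this diff):

# Hypothetical usage: prepare the inputs once, then run k draft steps
# without returning to the CPU between steps.
model_input = draft_runner.prepare_model_input(seq_group_metadata_list)
outputs = draft_runner.execute_model(model_input, kv_caches, num_steps=k)
assert len(outputs) == k  # one SamplerOutput per speculative step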
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index e469fd7c3a160..c1a02e1d32e85 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -6,6 +6,7 @@
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -67,22 +68,24 @@ def sampler_output(
         copied_execute_model_req = execute_model_req.clone(
             copied_seq_group_metadata_list)
 
-        # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
-                                     sample_len)
-
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        for _ in range(sample_len):
-            model_output: List[SamplerOutput] = super().execute_model(
+        if isinstance(self.model_runner, TP1DraftModelRunner):
+            copied_execute_model_req.num_steps = sample_len
+            model_outputs = self.execute_model(
                 execute_model_req=copied_execute_model_req)
-            assert (len(model_output) == 1
-                    ), "composing multistep workers not supported"
-            model_output = model_output[0]
-
-            self._append_new_tokens(model_output,
-                                    copied_seq_group_metadata_list)
-            model_outputs.append(model_output)
+        else:
+            # TODO: Remove this branch once DraftModelRunner supports TP>1.
+            for _ in range(sample_len):
+                model_output: List[SamplerOutput] = super().execute_model(
+                    execute_model_req=copied_execute_model_req)
+                assert (len(model_output) == 1
+                        ), "composing multistep workers not supported"
+                model_output = model_output[0]
+
+                self._append_new_tokens(model_output,
+                                        copied_seq_group_metadata_list)
+                model_outputs.append(model_output)
 
         return model_outputs, True
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 5089e3dd556e9..f1e64cae8fc5b 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -11,6 +11,7 @@
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
                            get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -117,6 +118,8 @@ def create_worker(
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
+            if draft_tp == 1:
+                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index e3464c0d3900c..aef7d5b604b27 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -362,7 +362,12 @@ def execute_model(
         self,
         model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
         model_executable = self.model
         execute_model_kwargs = {
             "input_ids": model_input.input_tokens,
@@ -382,11 +387,11 @@ def execute_model(
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
         # Sample the next token.
         output = self.model.sample(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py
index 3c8dfa2c6d8df..272917c7272df 100644
--- a/vllm/worker/embedding_model_runner.py
+++ b/vllm/worker/embedding_model_runner.py
@@ -57,7 +57,12 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[PoolerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[PoolerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "EmbeddingModelRunner does not support multi-step execution.")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -91,10 +96,12 @@ def execute_model(
 
         # Only perform pooling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
-        return self.model.pooler(hidden_states=hidden_states,
-                                 pooling_metadata=model_input.pooling_metadata)
+        return [
+            self.model.pooler(hidden_states=hidden_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 9fdb2ea5dd4e4..d71d2e0aa4a02 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -977,7 +977,11 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> SamplerOutput:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -1010,7 +1014,7 @@ def execute_model(
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
@@ -1026,7 +1030,7 @@ def execute_model(
                 0, model_input.sampling_metadata.selected_token_indices)
             output.hidden_states = hidden_states
 
-        return output
+        return [output]
 
 
 class CUDAGraphRunner:
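Because execute_model now returns a list of SamplerOutput rather than a bare SamplerOutput (and an empty list on non-driver workers), single-step callers unwrap the first element. A minimal sketch, with runner, model_input, and kv_caches as placeholders:

# Hypothetical sketch of the updated single-step contract.
outputs = runner.execute_model(model_input, kv_caches)  # num_steps defaults to 1
if outputs:                      # empty on non-driver workers
    sampler_output = outputs[0]  # single-step runners return one element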
""" diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index fec2c97e73889..2ccf4a50a87bd 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -207,7 +207,12 @@ def execute_model( self, model_input: ModelInputForNeuron, kv_caches: Optional[List[torch.Tensor]] = None, - ) -> Optional[SamplerOutput]: + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -223,7 +228,7 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - return output + return [output] @property def vocab_size(self) -> int: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2c70c1f917a0d..04364eab02f3f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -422,7 +422,12 @@ def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> SamplerOutput: + num_steps: int = 1, + ) -> List[SamplerOutput]: + if num_steps > 1: + raise ValueError( + "TPUModelRunner does not support multi-step execution.") + assert seq_group_metadata_list is not None assert len(seq_group_metadata_list) > 0 if seq_group_metadata_list[0].is_prompt: @@ -440,7 +445,7 @@ def execute_model( else: sampler_outputs = self._execute_model(seq_group_metadata_list, kv_caches) - return SamplerOutput(sampler_outputs) + return [SamplerOutput(sampler_outputs)] class ModelWrapper(nn.Module): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e1944a4f1d636..156d5278a292a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -45,6 +45,7 @@ def __init__( vision_language_config: Optional[VisionLanguageConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config @@ -78,7 +79,9 @@ def __init__( "mlp_speculator") else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if self.model_config.embedding_mode: + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self.model_config.embedding_mode: ModelRunnerClass = EmbeddingModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1df60eb1f38ce..d867e15bdf82d 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -228,11 +228,13 @@ def execute_model( model_input: ModelRunnerInputBase = ( self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list)) + num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update( model_input.as_broadcastable_tensor_dict()) + broadcast_data["num_steps"] = num_steps broadcast_tensor_dict(broadcast_data, src=0) else: assert self.do_metadata_broadcast @@ -240,6 +242,7 @@ def execute_model( if not broadcast_data: return None + num_steps = broadcast_data.pop("num_steps") worker_input = WorkerInput.from_broadcasted_tensor_dict( broadcast_data) model_input = ( @@ -252,10 +255,8 @@ def execute_model( if worker_input.num_seq_groups == 0: 
             return []
 
-        output = self.model_runner.execute_model(model_input, self.kv_cache)
-        # Worker only supports single-step execution. Wrap the output in a
-        # list to conform to interface.
-        return [output]
+        return self.model_runner.execute_model(model_input, self.kv_cache,
+                                               num_steps)
 
 
 class WorkerWrapperBase:
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index d9124a788a69d..99fd7da5edda5 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -334,7 +334,12 @@ def execute_model(
         self,
         model_input: ModelInputForXPU,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "XPUModelRunner does not support multi-step execution.")
+
         model_executable = self.model
         execute_model_kwargs = {
             "input_ids": model_input.input_tokens,
@@ -354,14 +359,14 @@ def execute_model(
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
         # Sample the next token.
         output = self.model.sample(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]
 
     def _prepare_prompt(
         self,
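For completeness, the new model_runner_cls hook can be exercised the same way the updated tests do it. A hypothetical sketch modeled on tests/spec_decode/test_multi_step_worker.py (create_worker comes from tests/spec_decode/utils.py; model_name, block_size, num_gpu_blocks, and seed are placeholders):

# Hypothetical usage: build a draft worker whose model runner is overridden
# with TP1DraftModelRunner, mirroring the updated test setup.
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker

multi_step_worker = create_worker(
    MultiStepWorker,
    model_name,
    block_size,
    num_gpu_blocks,
    seed,
    model_runner_cls=TP1DraftModelRunner,
)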