From a29ce7d2b046a2ec944a91dff702246c8362c694 Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Sat, 22 Mar 2025 08:09:51 -0700
Subject: [PATCH 1/3] draft

---
 vllm/v1/spec_decode/eagle_proposer.py | 11 ++++++
 vllm/v1/spec_decode/interface.py      | 50 ++++++++++++++++++++++++
 vllm/v1/spec_decode/ngram_proposer.py | 55 ++++++++++++++++++++++++++-
 vllm/v1/worker/gpu_model_runner.py    | 54 ++------------------------
 4 files changed, 119 insertions(+), 51 deletions(-)
 create mode 100644 vllm/v1/spec_decode/eagle_proposer.py
 create mode 100644 vllm/v1/spec_decode/interface.py

diff --git a/vllm/v1/spec_decode/eagle_proposer.py b/vllm/v1/spec_decode/eagle_proposer.py
new file mode 100644
index 000000000000..317a4b8f4b85
--- /dev/null
+++ b/vllm/v1/spec_decode/eagle_proposer.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class EagleProposer:
+
+    def generate_draft_token_ids(
+            self, input_batch: InputBatch, sampled_token_ids: list[list[int]],
+            sampling_metadata: SamplingMetadata) -> list[list[int]]:
+        raise NotImplementedError("This method is not implemented yet.")
diff --git a/vllm/v1/spec_decode/interface.py b/vllm/v1/spec_decode/interface.py
new file mode 100644
index 000000000000..5337b3f079fd
--- /dev/null
+++ b/vllm/v1/spec_decode/interface.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from typing import Union
+
+from vllm.config import SpeculativeConfig
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.eagle_proposer import EagleProposer
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class ProposerInterface(ABC):
+    """Abstract base class for speculative proposers."""
+
+    @abstractmethod
+    def generate_draft_token_ids(
+            self, input_batch: InputBatch, sampled_token_ids: list[list[int]],
+            sampling_metadata: SamplingMetadata) -> list[list[int]]:
+        """Generates draft tokens using a speculative proposal strategy.
+        NOTE: This function mutates the input_batch by writing the
+        proposed tokens to token_ids_cpu.
+        Args:
+            input_batch: Contains input data and sequence metadata.
+            sampled_token_ids: Tokens already sampled in previous steps.
+            sampling_metadata: Additional sampling parameters and
+                constraints.
+
+        Returns:
+            List of draft token IDs for each sequence in
+            the input batch.
+        """
+        raise NotImplementedError
+
+
+def create_proposer(
+    speculative_config: SpeculativeConfig
+) -> Union[NgramProposer, EagleProposer]:
+    """Factory function for creating proposer instances."""
+
+    if speculative_config.type == "ngram":
+        return NgramProposer(n=speculative_config.ngram_n,
+                             k=speculative_config.ngram_k)
+
+    elif speculative_config.type == "eagle":
+        return EagleProposer()
+
+    else:
+        raise ValueError(
+            f"Unsupported proposer type: {speculative_config.type}. "
+            "Valid types: 'ngram', 'eagle'")
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 33289d05dabd..f1429e4650ec 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -4,8 +4,61 @@
 import numpy as np
 from numba import jit
 
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.interface import ProposerInterface
+from vllm.v1.spec_decode.utils import is_spec_decode_supported
+from vllm.v1.worker.gpu_input_batch import InputBatch
 
-class NgramProposer:
+
+class NgramProposer(ProposerInterface):
+
+    def __init__(self, n, k):
+        self.n = n
+        self.k = k
+        # Trigger Numba JIT compilation for the N-gram proposer.
+        # This usually takes less than 1 second.
+        self.propose(
+            np.zeros(1024, dtype=np.int32),
+            self.n,
+            self.k,
+        )
+
+    def generate_draft_token_ids(
+            self, input_batch: InputBatch, sampled_token_ids: list[list[int]],
+            sampling_metadata: SamplingMetadata) -> list[list[int]]:
+        '''
+        Propose tokens based on input_batch and sampled_token_ids.
+        NOTE: This function mutates the input_batch by writing the
+        proposed tokens to token_ids_cpu.
+        '''
+        draft_token_ids: list[list[int]] = []
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            num_sampled_ids = len(sampled_ids)
+            if not num_sampled_ids:
+                # Skip speculative decoding.
+                draft_token_ids.append([])
+                continue
+
+            # Skip requests that require top-p, top-k, etc.
+            req_id = input_batch.req_ids[i]
+            if not is_spec_decode_supported(req_id, input_batch):
+                draft_token_ids.append([])
+                continue
+
+            # Add sampled_token_ids to token_ids_cpu.
+            start_idx = input_batch.num_tokens_no_spec[i]
+            end_idx = start_idx + num_sampled_ids
+            input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
+            drafter_output = self.propose(
+                input_batch.token_ids_cpu[i, :end_idx],
+                self.n,
+                self.k,
+            )
+            if drafter_output is None or len(drafter_output) == 0:
+                draft_token_ids.append([])
+            else:
+                draft_token_ids.append(drafter_output.tolist())
+        return draft_token_ids
 
     def propose(
         self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b186300a0033..28aa46e84590 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -35,9 +35,8 @@
                              ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
+from vllm.v1.spec_decode.interface import create_proposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -151,18 +150,8 @@ def __init__(
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
-            # TODO: find a better way to check if we are using ngram.
-            assert self.speculative_config.ngram_prompt_lookup_min, \
-                "Currently, only ngram spec decode is supported in V1."
             if get_pp_group().is_last_rank:
-                self.drafter = NgramProposer()
-                # Trigger Numba JIT compilation for N-gram proposer.
-                # This usually takes less than 1 second.
-                self.drafter.propose(
-                    np.zeros(1024, dtype=np.int32),
-                    self.speculative_config.ngram_prompt_lookup_min,
-                    self.speculative_config.num_speculative_tokens,
-                )
+                self.drafter = create_proposer(self.speculative_config)
                 self.rejection_sampler = RejectionSampler()
 
         # Request states.
@@ -1117,8 +1106,8 @@ def execute_model(
         if not self.use_spec_decode:
             spec_token_ids = None
         else:
-            spec_token_ids = self.generate_draft_token_ids(
-                valid_sampled_token_ids, sampling_metadata)
+            spec_token_ids = self.drafter.generate_draft_token_ids(
+                self.input_batch, valid_sampled_token_ids, sampling_metadata)
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1129,41 +1118,6 @@ def execute_model(
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
-    def generate_draft_token_ids(
-        self,
-        sampled_token_ids: list[list[int]],
-        sampling_metadata: SamplingMetadata,
-    ) -> list[list[int]]:
-        # TODO(woosuk): Optimize.
-        draft_token_ids: list[list[int]] = []
-        for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
-            if not num_sampled_ids:
-                # Skip speculative decoding.
-                draft_token_ids.append([])
-                continue
-
-            # Skip requests that require top-p, top-k, etc.
-            req_id = self.input_batch.req_ids[i]
-            if not is_spec_decode_supported(req_id, self.input_batch):
-                draft_token_ids.append([])
-                continue
-
-            # Add sampled_token_ids to token_ids_cpu.
-            start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + num_sampled_ids
-            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
-            drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx],
-                self.speculative_config.ngram_prompt_lookup_min,
-                self.speculative_config.num_speculative_tokens,
-            )
-            if drafter_output is None or len(drafter_output) == 0:
-                draft_token_ids.append([])
-            else:
-                draft_token_ids.append(drafter_output.tolist())
-        return draft_token_ids
-
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117

From 9b11746aa14ea01c05f2b7e80fb560b815a816fe Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Sat, 22 Mar 2025 08:30:14 -0700
Subject: [PATCH 2/3] scheduler

---
 vllm/v1/core/sched/scheduler.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d002a19b08a4..cd59ad539d76 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -104,6 +104,12 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)
 
+        # Speculative decoding related.
+        self.num_lookahead_slots = 0
+        if self.speculative_config and self.speculative_config.type == "eagle":
+            self.num_lookahead_slots = \
+                    self.speculative_config.num_speculative_tokens
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -150,7 +156,8 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue
 
-            num_new_tokens = (request.num_tokens_with_spec -
+            num_new_tokens = (request.num_tokens_with_spec +
+                              self.num_lookahead_slots -
                               request.num_computed_tokens)
             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0

From 3a6cb36767333db6519170d29fa10c2136c9be8d Mon Sep 17 00:00:00 2001
From: LiuXiaoxuanPKU
Date: Sat, 22 Mar 2025 08:31:05 -0700
Subject: [PATCH 3/3] scheduler

---
 vllm/v1/core/sched/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cd59ad539d76..1cd306c85bee 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -108,7 +108,7 @@ def __init__(
         self.num_lookahead_slots = 0
         if self.speculative_config and self.speculative_config.type == "eagle":
             self.num_lookahead_slots = \
-                    self.speculative_config.num_speculative_tokens
+                self.speculative_config.num_speculative_tokens
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
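
For context, a minimal sketch of how the proposer interface from PATCH 1/3 is meant to be driven. Only create_proposer and generate_draft_token_ids come from this series; the helper name run_draft_step and the objects passed into it stand in for state the GPU model runner already owns (its SpeculativeConfig, InputBatch, sampled token lists, and SamplingMetadata), so treat this as an illustration rather than part of the diff.

    # Illustrative sketch only -- not part of the patch series above.
    from vllm.config import SpeculativeConfig
    from vllm.v1.sample.metadata import SamplingMetadata
    from vllm.v1.spec_decode.interface import create_proposer
    from vllm.v1.worker.gpu_input_batch import InputBatch


    def run_draft_step(speculative_config: SpeculativeConfig,
                       input_batch: InputBatch,
                       valid_sampled_token_ids: list[list[int]],
                       sampling_metadata: SamplingMetadata) -> list[list[int]]:
        # In the model runner the drafter is built once in __init__;
        # create_proposer returns an NgramProposer or EagleProposer based on
        # speculative_config.type ("ngram" or "eagle").
        drafter = create_proposer(speculative_config)
        # Writes the proposed tokens into input_batch.token_ids_cpu and
        # returns one draft-token list per request in the batch.
        return drafter.generate_draft_token_ids(input_batch,
                                                valid_sampled_token_ids,
                                                sampling_metadata)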