diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index af65b6d38e02..2d28eca205b5 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json

 from transformers import AutoTokenizer

@@ -54,9 +55,14 @@ def parse_args():
         "--method",
         type=str,
         default="eagle",
-        choices=["ngram", "eagle", "eagle3", "mtp"],
+        choices=["ngram", "eagle", "eagle3", "mtp", "ngram-eagle"],
     )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
+    parser.add_argument(
+        "--num-speculative-tokens-per-method",
+        type=json.loads,
+        default='{"ngram": 2, "eagle": 2}',
+    )
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
     parser.add_argument("--prompt-lookup-min", type=int, default=2)
     parser.add_argument("--tp", type=int, default=1)
@@ -119,6 +125,21 @@ def main(args):
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
         }
+    elif args.method == "ngram-eagle":
+        eagle_dir = args.eagle_dir
+        if eagle_dir is None:
+            eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+        args.num_spec_tokens = max(
+            args.num_speculative_tokens_per_method["ngram"],
+            args.num_speculative_tokens_per_method["eagle"],
+        )
+        speculative_config = {
+            "method": "ngram-eagle",
+            "model": eagle_dir,
+            "num_speculative_tokens_per_method": args.num_speculative_tokens_per_method,
+            "prompt_lookup_max": args.prompt_lookup_max,
+            "prompt_lookup_min": args.prompt_lookup_min,
+        }
     elif args.method == "mtp":
         speculative_config = {
             "method": "mtp",
@@ -156,6 +177,7 @@ def main(args):
         print("-" * 50)
         print(f"prompt: {output.prompt}")
         print(f"generated text: {output.outputs[0].text}")
+        print(f"num of generated tokens: {len(output.outputs[0].token_ids)}")
         print("-" * 50)

     try:
@@ -185,6 +207,10 @@ def main(args):
             assert isinstance(metric, Vector)
             for pos in range(len(metric.values)):
                 acceptance_counts[pos] += metric.values[pos]
+        elif metric.name == "vllm:generation_tokens":
+            assert isinstance(metric, Counter)
+            print(f"num generation tokens: {metric.value}")
+            total_tokens_generated = metric.value

     print("-" * 50)
     print(f"total_num_output_tokens: {total_num_output_tokens}")
@@ -193,6 +219,14 @@ def main(args):
     print(f"num_accepted_tokens: {num_accepted_tokens}")
     acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
     print(f"mean acceptance length: {acceptance_length:.2f}")
+    num_tokens_generated_without_sd = total_tokens_generated - (
+        num_drafts + num_accepted_tokens
+    )
+    seq_normalized_acceptance_length = (total_tokens_generated) / (
+        num_drafts + num_tokens_generated_without_sd
+    )
+    print(f"num_tokens_generated_without_sd: {num_tokens_generated_without_sd}")
+    print(f"seq normalized acceptance length: {seq_normalized_acceptance_length:.2f}")
     print("-" * 50)

     # print acceptance at each token position
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 8f048775352e..6414bc5673d5 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -136,6 +136,8 @@ def test_ngram_correctness(
                      "head_dim not being a a multiple of 32")),
     (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
       "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("ngram-eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
     (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
       "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
     pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
@@ -150,8 +152,10 @@ def test_ngram_correctness(
       "eagle618/eagle-deepseek-v3-random", 1), False),
 ],
     ids=[
-        "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
-        "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
+        "qwen3_eagle3", "qwen2_5_vl_eagle3",
+        "llama3_eagle", "llama3_ngram_eagle",
+        "llama3_eagle3", "llama4_eagle",
+        "llama4_eagle_mm", "deepseek_eagle"
     ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
@@ -202,16 +206,32 @@ def test_eagle_correctness(
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()

-    spec_llm = LLM(
-        model=model_name,
-        trust_remote_code=True,
-        tensor_parallel_size=tp_size,
-        speculative_config={
+    if method == "ngram-eagle":
+        # Use ngram-eagle specific config
+        speculative_config = {
+            "method": method,
+            "model": spec_model_name,
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens_per_method": {
+                "ngram": 3,
+                "eagle": 3
+            },
+            "max_model_len": 2048,
+        }
+    else:
+        speculative_config = {
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
             "max_model_len": 2048,
-        },
+        }
+
+    spec_llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        tensor_parallel_size=tp_size,
+        speculative_config=speculative_config,
         max_model_len=2048,
     )
     spec_outputs = spec_llm.chat(test_prompts, sampling_config)
diff --git a/tests/v1/spec_decode/test_ngram_eagle.py b/tests/v1/spec_decode/test_ngram_eagle.py
new file mode 100644
index 000000000000..8b91d9138fd6
--- /dev/null
+++ b/tests/v1/spec_decode/test_ngram_eagle.py
@@ -0,0 +1,187 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest import mock
+
+import pytest
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.platforms import current_platform
+from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+
+NUM_SPECULATIVE_TOKENS_NGRAM = 5
+NUM_SPECULATIVE_TOKENS_EAGLE = 3
+PROMPT_LOOKUP_MIN = 2
+PROMPT_LOOKUP_MAX = 5
+DEVICE = current_platform.device_type
+
+
+def _create_vllm_config(num_speculative_tokens_ngram: int,
+                        num_speculative_tokens_eagle: int):
+    model_config = ModelConfig(model=model_dir,
+                               runner="generate",
+                               max_model_len=100)
+
+    # Choose model directory based on method
+    draft_model_dir = eagle_dir
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        model=draft_model_dir,
+        method="ngram-eagle",
+        num_speculative_tokens_per_method={
+            "ngram": num_speculative_tokens_ngram,
+            "eagle": num_speculative_tokens_eagle
+        },
+        prompt_lookup_max=PROMPT_LOOKUP_MAX,
+        prompt_lookup_min=PROMPT_LOOKUP_MIN,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
+
+    return vllm_config
+
+
+def test_proposer_config():
+
+    vllm_config = _create_vllm_config(NUM_SPECULATIVE_TOKENS_NGRAM,
+                                      NUM_SPECULATIVE_TOKENS_EAGLE)
+
+    # ngram proposer
+    ngram_proposer = NgramProposer(vllm_config=vllm_config)
+    assert ngram_proposer.k == NUM_SPECULATIVE_TOKENS_NGRAM
+    assert ngram_proposer.min_n == PROMPT_LOOKUP_MIN
+    assert ngram_proposer.max_n == PROMPT_LOOKUP_MAX
+
+    # eagle proposer
+    eagle_proposer = EagleProposer(vllm_config=vllm_config,
+                                   device=current_platform.device_type)
+    assert eagle_proposer.num_speculative_tokens == NUM_SPECULATIVE_TOKENS_EAGLE
+
+
+@pytest.mark.parametrize(
+    "test_value",
+    [
+        {
+            "sampled_token_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
+            # ngram draft is empty
+            "propose_ngram_draft_token_ids": [[]]
+        },
+        {
+            "sampled_token_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]],
+            # ngram draft is not empty
+            "propose_ngram_draft_token_ids": [[4, 5, 6, 7, 8]]
+        }
+    ])
+@pytest.mark.parametrize("pp_size", [1, 2])
+@mock.patch('vllm.v1.worker.gpu_model_runner.get_pp_group')
+@mock.patch(
+    'vllm.v1.worker.gpu_model_runner.GPUModelRunner.propose_ngram_draft_token_ids'
+)
+@mock.patch('vllm.v1.worker.gpu_model_runner.EagleProposer.propose',
+            return_value=torch.tensor([[0, 1, 2]]))
+@mock.patch('vllm.v1.worker.gpu_model_runner.EagleProposer.prepare_inputs',
+            return_value=(None, 0))
+def test_propose_draft_token_ids(
+    mock_eagle_proposer_prepare_input,
+    mock_eagle_proposer_propose,
+    mock_propose_ngram_draft_token_ids,
+    mock_get_pp_group,
+    test_value,
+    pp_size,
+):
+
+    vllm_config = _create_vllm_config(NUM_SPECULATIVE_TOKENS_NGRAM,
+                                      NUM_SPECULATIVE_TOKENS_EAGLE)
+
+    runner = GPUModelRunner(vllm_config, DEVICE)
+
+    # Setup mock for pp group to return the appropriate value for world size
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = pp_size
+    mock_get_pp_group.return_value = mock_pp_group
+
+    sampled_token_ids = test_value["sampled_token_ids"]
+    propose_ngram_draft_token_ids = test_value["propose_ngram_draft_token_ids"]
+
+    # with min matching ngram = 2, max matching ngram = 3
+    # we will find the prefix [1, 2, 3] in the history
+    # and speculate [4, 5, 6, 7, 8] for ngram
+    expected_ngram_proposals = [[4, 5, 6, 7, 8]]
+    expected_eagle_proposals = [[
+        i for i in range(NUM_SPECULATIVE_TOKENS_EAGLE)
+    ]]
+    mock_propose_ngram_draft_token_ids \
+        .return_value = propose_ngram_draft_token_ids
+
+    # doesn't matter what this is for this test: START
+    scheduler_output = mock.MagicMock()
+    scheduler_output.total_num_scheduled_tokens = 1 + max(
+        vllm_config.speculative_config.
+        num_speculative_tokens_per_method["ngram"], vllm_config.
+        speculative_config.num_speculative_tokens_per_method["eagle"])
+    hidden_states = torch.randn(len(sampled_token_ids[0]), 4096)
+    sample_hidden_states = None
+    aux_hidden_states = None
+    spec_decode_metadata = mock.MagicMock()
+    spec_decode_metadata.num_draft_tokens = [
+        max(NUM_SPECULATIVE_TOKENS_NGRAM, NUM_SPECULATIVE_TOKENS_EAGLE)
+    ]
+    common_attn_metadata = None
+    sampling_metadata = None
+
+    # set runner attributes that would normally be set during init
+    runner.supports_mm_inputs = False
+
+    mock_positions = mock.MagicMock()
+    mock_positions_instance = mock_positions.return_value
+    mock_positions_instance.gpu = torch.tensor([0])
+    runner.positions = mock_positions_instance
+
+    mock_input_ids = mock.MagicMock()
+    mock_input_ids_instance = mock_input_ids.return_value
+    mock_input_ids_instance.gpu = torch.tensor([0])
+    runner.input_ids = mock_input_ids_instance
+
+    mock_req_ids = mock.MagicMock()
+    mock_req_ids.return_value = ["0"]
+    # doesn't matter what this is for this test: END
+
+    final_draft = runner.propose_draft_token_ids(
+        scheduler_output=scheduler_output,
+        sampled_token_ids=sampled_token_ids,
+        sampling_metadata=sampling_metadata,
+        hidden_states=hidden_states,
+        sample_hidden_states=sample_hidden_states,
+        aux_hidden_states=aux_hidden_states,
+        spec_decode_metadata=spec_decode_metadata,
+        common_attn_metadata=common_attn_metadata,
+    )
+
+    # case 1: ngram draft is empty. Eagle draft is used
+    if sampled_token_ids == [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
+        assert final_draft == expected_eagle_proposals, \
+            "ngram-eagle should have selected eagle draft"
+    # case 2: ngram draft is not empty. Ngram draft is used
+    elif sampled_token_ids == [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3]]:
+        assert final_draft == expected_ngram_proposals, \
+            "ngram-eagle should have selected ngram draft"
+    else:
+        raise ValueError("unexpected sampled_token_ids")
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index d5c6d1d4d866..3958086de3f7 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -29,10 +29,10 @@
 logger = init_logger(__name__)

-SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
-                            "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                            "longcat_flash_mtp", "mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "ngram-eagle",
+                            "medusa", "mlp_speculator", "draft_model",
+                            "deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
+                            "mimo_mtp", "longcat_flash_mtp", "mtp"]
 MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                    "qwen3_next_mtp", "longcat_flash_mtp")

@@ -47,6 +47,9 @@ class SpeculativeConfig:
     num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
+    num_speculative_tokens_per_method: Optional[dict[str, int]] = None
+    """The number of speculative tokens for each method, if provided. Max of
+    the values will be used if `num_speculative_tokens` is not provided."""
     model: Optional[str] = None
     """The name of the draft model, eagle head, or additional weights, if
     provided."""
@@ -238,6 +241,25 @@ def __post_init__(self):
                     "num_speculative_tokens was provided but without "
                     "speculative model.")

+        # set num_speculative_tokens from num_speculative_tokens_per_method
+        # for methods like ngram-eagle
+        if self.num_speculative_tokens_per_method is not None:
+            assert all(
+                isinstance(v, int) and v > 0
+                for v in self.num_speculative_tokens_per_method.values()), (
+                    "All values in num_speculative_tokens_per_method must be "
+                    "positive integers.")
+            max_num_speculative_tokens = max(
+                self.num_speculative_tokens_per_method.values())
+            if self.num_speculative_tokens is None:
+                self.num_speculative_tokens = max_num_speculative_tokens
+            else:
+                assert self.num_speculative_tokens <= \
+                    max_num_speculative_tokens, (
+                        "num_speculative_tokens should be None or must be"
+                        " less than or equal to the "
+                        "max value in num_speculative_tokens_per_method.")
+
         # Automatically configure the method for ngram when "model" is used
         # instead of "method"
         if self.method is None and (self.model is not None
@@ -247,6 +269,8 @@
         if self.method in ("ngram", "[ngram]"):
             # Unified to "ngram" internally
             self.method = "ngram"
+
+        if self.method in ("ngram", "ngram-eagle"):
             # Set default values if not provided
             if (self.prompt_lookup_min is None
                     and self.prompt_lookup_max is None):
@@ -277,9 +301,13 @@
             # draft related config as None here.
             self.draft_model_config = self.target_model_config
             self.draft_parallel_config = self.target_parallel_config
-        else:
-            self.prompt_lookup_max = 0
-            self.prompt_lookup_min = 0
+
+        # allow ngram-eagle to use this code block similar to eagle
+        if self.method not in ("ngram", ):
+
+            if self.method != "ngram-eagle":
+                self.prompt_lookup_max = 0
+                self.prompt_lookup_min = 0

             if self.model is not None:
                 # TODO: Move this import to the top once `ModelConfig`
@@ -311,7 +339,7 @@ def __post_init__(self):
                 )

                 # Automatically detect the method
-                if self.method in ('eagle', 'eagle3'):
+                if self.method in ('eagle', 'eagle3', 'ngram-eagle'):
                     pass
                 # examples:
                 # yuhuili/EAGLE-LLaMA3-Instruct-8B
@@ -353,7 +381,7 @@ def __post_init__(self):
                         "eagle, or mtp.")

             # Replace hf_config for EAGLE draft_model
-            if self.method in ("eagle", "eagle3"):
+            if self.method in ("eagle", "eagle3", "ngram-eagle"):
                 if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                     raise ValueError(
                         "Chunked prefill and EAGLE are not compatible "
@@ -425,6 +453,12 @@ def __post_init__(self):
                     self.target_parallel_config,
                     self.draft_tensor_parallel_size))

+        if self.use_ngram() and not self.disable_padded_drafter_batch:
+            logger.warning(
+                "padded_drafter_batch has to be disabled with ngram. "
+                "Setting disable_padded_drafter_batch to True.")
+            self.disable_padded_drafter_batch = True
+
     @staticmethod
     def _maybe_override_draft_max_model_len(
         speculative_max_model_len: Optional[int],
@@ -523,6 +557,27 @@ def _verify_args(self) -> Self:
                 "speculative model unless the draft model config contains an "
                 "n_predict parameter.")

+        if self.method == "ngram-eagle":
+            assert self.num_speculative_tokens_per_method is not None, (
+                "num_speculative_tokens_per_method must be provided for "
+                "ngram-eagle method.")
+            assert "ngram" in self.num_speculative_tokens_per_method, (
+                "num_speculative_tokens_per_method must contain ngram key for "
+                "ngram-eagle method.")
+            assert "eagle" in self.num_speculative_tokens_per_method, (
+                "num_speculative_tokens_per_method must contain eagle key for "
+                "ngram-eagle method.")
+            ngram_speculative_tokens = \
+                self.num_speculative_tokens_per_method["ngram"]
+            eagle_speculative_tokens = \
+                self.num_speculative_tokens_per_method["eagle"]
+            if self.num_speculative_tokens != \
+                    max(ngram_speculative_tokens, eagle_speculative_tokens):
+                raise ValueError(
+                    "num_speculative_tokens must be the max value in "
+                    "num_speculative_tokens_per_method for ngram-eagle method."
+                )
+
         if self.num_speculative_tokens <= 0:
             raise ValueError("Expected num_speculative_tokens to be greater "
                              f"than zero ({self.num_speculative_tokens}).")
@@ -559,7 +614,10 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens

     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3", "mtp")
+        return self.method in ("eagle", "eagle3", "ngram-eagle", "mtp")
+
+    def use_ngram(self) -> bool:
+        return self.method == "ngram" or self.method == "ngram-eagle"

     def __repr__(self) -> str:
         method = self.method
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 444ed70de3d0..53d32a730ded 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -47,7 +47,7 @@ def __init__(self,
         # LlamaForCausalLM -> EagleLlamaForCausalLM
         # LlamaForCausalLM -> Eagle3LlamaForCausalLM
         # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3
-        if method == "eagle":
+        if method in ("eagle", "ngram-eagle"):
             assert self.model is not None, \
                 "model should not be None when method is eagle"
             kwargs["architectures"] = [
@@ -63,8 +63,9 @@ def __init__(self,
                 else f"Eagle3{arch}" for arch in self.model.architectures
             ]
         else:
-            raise ValueError(f"Invalid method {method}. "
-                             "Supported methods are eagle and eagle3.")
+            raise ValueError(
+                f"Invalid method {method}. "
+                "Supported methods are eagle, ngram-eagle and eagle3.")

         super().__init__(**kwargs)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index dc6db0138806..aa7233e04bff 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -59,8 +59,21 @@ def __init__(
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = (
-            self.speculative_config.num_speculative_tokens)
+
+        if self.method == "ngram-eagle":
+            self.num_speculative_tokens = (
+                self.speculative_config.
+                num_speculative_tokens_per_method["eagle"])
+        else:
+            self.num_speculative_tokens = (
+                self.speculative_config.num_speculative_tokens)
+
+        logger.info(
+            "EagleProposer: method=%s, num_speculative_tokens=%s",
+            self.method,
+            self.num_speculative_tokens,
+        )
+
         self.max_num_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
         self.token_arange_np = np.arange(self.max_num_tokens)
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index aed050a3540c..0376c61a640b 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -6,6 +6,9 @@
 from numba import get_num_threads, jit, njit, prange, set_num_threads

 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)


 class NgramProposer:
@@ -22,7 +25,15 @@ def __init__(self, vllm_config: VllmConfig):
         # Number of tokens follow the match. If there are less than k
         # tokens follow the match, we will return the maximum amount of
         # tokens until the end.
-        self.k = vllm_config.speculative_config.num_speculative_tokens
+        self.method = vllm_config.speculative_config.method
+        if self.method == "ngram-eagle":
+            self.k = vllm_config \
+                .speculative_config \
+                .num_speculative_tokens_per_method["ngram"]
+        else:
+            self.k = vllm_config \
+                .speculative_config \
+                .num_speculative_tokens

         # Maximum length of the model.
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -58,6 +69,13 @@ def __init__(self, vllm_config: VllmConfig):
         self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
                      np.zeros((1024, self.max_model_len), dtype=np.int32),
                      set())
+        logger.info(
+            "NgramProposer: min_n=%s, max_n=%s, k=%s, max_model_len=%s",
+            self.min_n,
+            self.max_n,
+            self.k,
+            self.max_model_len,
+        )

     def batch_propose(
         self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9941cacae8ab..f46994da0ade 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -280,19 +280,27 @@ def __init__(
         # NOTE(Jiayi): currently we put the entire draft model on
         # the last PP rank. This is not ideal if there are many
         # layers in the draft model.
+        found_draft = False
         if self.speculative_config and get_pp_group().is_last_rank:
-            if self.speculative_config.method == "ngram":
-                self.drafter = NgramProposer(self.vllm_config)
-            elif self.speculative_config.use_eagle():
-                self.drafter = EagleProposer(self.vllm_config, self.device,
-                                             self)  # type: ignore
+            # use ifs and not elifs to allow multiple
+            # draft models to be initialized
+            if self.speculative_config.method == "ngram" \
+                    or self.speculative_config.method == "ngram-eagle":
+                self.drafter_ngram = NgramProposer(self.vllm_config)
+                found_draft = True
+            if self.speculative_config.use_eagle():
+                self.drafter_eagle = EagleProposer(self.vllm_config,
+                                                   self.device,
+                                                   self)  # type: ignore
                 if self.speculative_config.method == "eagle3":
                     self.use_aux_hidden_state_outputs = True
-            elif self.speculative_config.method == "medusa":
+                found_draft = True
+            if self.speculative_config.method == "medusa":
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
-            else:
+                found_draft = True
+            if not found_draft:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
             self.rejection_sampler = RejectionSampler()
@@ -1244,8 +1252,8 @@ def _prepare_inputs(

             if (self.speculative_config
                     and spec_decode_common_attn_metadata is None):
-                if isinstance(self.drafter, EagleProposer):
-                    if (self.drafter.attn_layer_names[0]
+                if isinstance(self.drafter_eagle, EagleProposer):
+                    if (self.drafter_eagle.attn_layer_names[0]
                             in kv_cache_group_spec.layer_names):
                         spec_decode_common_attn_metadata = common_attn_metadata
                 else:
@@ -2484,7 +2492,8 @@ def propose_draft_token_ids(sampled_token_ids):
             use_padded_batch_for_eagle = self.speculative_config and \
                 self.speculative_config.use_eagle() and \
-                not self.speculative_config.disable_padded_drafter_batch
+                not self.speculative_config.disable_padded_drafter_batch and \
+                not self.speculative_config.use_ngram()
             effective_drafter_max_model_len = self.max_model_len
             if effective_drafter_max_model_len is None:
                 effective_drafter_max_model_len = self.model_config.max_model_len

@@ -2569,14 +2578,17 @@ def propose_draft_token_ids(
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        if self.speculative_config.method == "ngram":
+        if self.speculative_config.method == "ngram" \
+                or self.speculative_config.method == "ngram-eagle":
+            assert isinstance(self.drafter_ngram, NgramProposer)
             assert isinstance(sampled_token_ids, list)
-            assert isinstance(self.drafter, NgramProposer)
-            draft_token_ids = self.drafter.propose(
+            draft_token_ids = self.drafter_ngram.propose(
                 sampled_token_ids, self.input_batch.req_ids,
                 self.input_batch.num_tokens_no_spec,
                 self.input_batch.token_ids_cpu,
                 self.input_batch.spec_decode_unsupported_reqs)
+            if self.speculative_config.method == "ngram-eagle":
+                draft_token_ids_ngram = draft_token_ids
         elif self.speculative_config.method == "medusa":
             assert isinstance(sampled_token_ids, list)
             assert isinstance(self.drafter, MedusaProposer)
@@ -2600,8 +2612,8 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
-        elif self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+        if self.speculative_config.use_eagle():
+            assert isinstance(self.drafter_eagle, EagleProposer)

             if self.speculative_config.disable_padded_drafter_batch:
                 # When padded-batch is disabled, the sampled_token_ids should be
@@ -2610,7 +2622,7 @@ def propose_draft_token_ids(
                 assert isinstance(sampled_token_ids, list), \
                     "sampled_token_ids should be a python list when" \
                     "padded-batch is disabled."
-                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
+                next_token_ids = self.drafter_eagle.prepare_next_token_ids_cpu(
                     sampled_token_ids, self.requests, self.input_batch,
                     scheduler_output.num_scheduled_tokens)
             else:
@@ -2622,7 +2634,7 @@ def propose_draft_token_ids(
                     "sampled_token_ids should be a torch.Tensor when" \
                     "padded-batch is enabled."
                 next_token_ids, valid_sampled_tokens_count = \
-                    self.drafter.prepare_next_token_ids_padded(
+                    self.drafter_eagle.prepare_next_token_ids_padded(
                         common_attn_metadata,
                         sampled_token_ids,
                         self.requests,
@@ -2647,14 +2659,14 @@ def propose_draft_token_ids(
             if self.speculative_config.disable_padded_drafter_batch:
                 token_indices_to_sample = None
                 common_attn_metadata, token_indices =\
-                    self.drafter.prepare_inputs(
+                    self.drafter_eagle.prepare_inputs(
                         common_attn_metadata, sampled_token_ids,
                         spec_decode_metadata.num_draft_tokens)
             else:
                 common_attn_metadata, token_indices, \
                     token_indices_to_sample =\
-                    self.drafter.prepare_inputs_padded(
+                    self.drafter_eagle.prepare_inputs_padded(
                         common_attn_metadata,
                         spec_decode_metadata,
                         valid_sampled_tokens_count)

@@ -2676,7 +2688,7 @@ def propose_draft_token_ids(
             else:
                 mm_embed_inputs = None

-            draft_token_ids = self.drafter.propose(
+            draft_token_ids = self.drafter_eagle.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -2687,6 +2699,25 @@ def propose_draft_token_ids(
                 mm_embed_inputs=mm_embed_inputs,
             )

+        if self.speculative_config.method == "ngram-eagle":
+            draft_token_ids_eagle = draft_token_ids
+
+        if self.speculative_config.method == "ngram-eagle":
+            assert draft_token_ids_ngram is not None, "ngram proposer failed"
+            assert draft_token_ids_eagle is not None, "eagle proposer failed"
+            # eagle draft is torch but we need list
+            draft_token_ids_eagle = draft_token_ids_eagle.tolist()
+            draft_token_ids = []
+
+            # combine ngram and eagle drafts
+            # prefer ngram drafts when available
+            # choose eagle drafts when ngram drafts are empty
+            for bid in range(len(draft_token_ids_ngram)):
+                if len(draft_token_ids_ngram[bid]):
+                    draft_token_ids.append(draft_token_ids_ngram[bid])
+                else:
+                    draft_token_ids.append(draft_token_ids_eagle[bid])
+
         return draft_token_ids

     def update_config(self, overrides: dict[str, Any]) -> None:
@@ -2745,6 +2776,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")
             self.drafter.load_model(self.model)
+        if hasattr(self, "drafter_eagle"):
+            logger.info("Loading eagle drafter model...")
+            self.drafter_eagle.load_model(self.model)
         if self.use_aux_hidden_state_outputs:
             if supports_eagle3(self.model):
                 self.model.set_aux_hidden_state_layers(
@@ -3266,8 +3300,8 @@ def _dummy_run(
         hidden_states = outputs

         if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            assert isinstance(self.drafter_eagle, EagleProposer)
+            self.drafter_eagle.dummy_run(num_tokens)

         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
@@ -4096,10 +4130,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

         if self.speculative_config and self.speculative_config.use_eagle():
-            assert isinstance(self.drafter, EagleProposer)
+            assert isinstance(self.drafter_eagle, EagleProposer)
             # validate all draft model layers belong to the same kv cache
             # group
-            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+            self.drafter_eagle.validate_same_kv_cache_group(kv_cache_config)

         if has_kv_transfer_group():
             kv_transfer_group = get_kv_transfer_group()
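
Usage note (not part of the patch): the new "ngram-eagle" method runs both proposers and combines their drafts per request, preferring the n-gram draft whenever the prompt-lookup match produces tokens and falling back to the EAGLE draft when the n-gram draft is empty. The sketch below is one way to exercise the feature offline; it mirrors the speculative_config shape used in examples/offline_inference/spec_decode.py and the e2e test above, with the default model names and token counts from this patch. Treat it as an illustrative sketch, not a prescribed configuration; the prompt string is an arbitrary placeholder.

    from vllm import LLM, SamplingParams

    # Target model plus the ngram-eagle speculative config wired up in this patch.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config={
            "method": "ngram-eagle",
            # EAGLE head used when the n-gram lookup finds no match
            "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            # per-method draft lengths; num_speculative_tokens defaults to their max
            "num_speculative_tokens_per_method": {"ngram": 2, "eagle": 2},
            "prompt_lookup_min": 2,
            "prompt_lookup_max": 5,
        },
        max_model_len=2048,
    )

    outputs = llm.generate(
        ["The future of AI is"],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)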