From 19b8b73aecb52b1fb93572f22dafe0ddee3a382b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 2 Apr 2025 15:19:34 -0700 Subject: [PATCH 01/17] register model Signed-off-by: LiuXiaoxuanPKU --- examples/offline_inference/eagle.py | 1 + vllm/model_executor/models/llama_eagle.py | 120 ++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/eagle.py | 5 +- vllm/v1/spec_decode/eagle.py | 28 ++--- 5 files changed, 134 insertions(+), 21 deletions(-) create mode 100644 vllm/model_executor/models/llama_eagle.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index db5012bae293..f7b4765f78dd 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -70,6 +70,7 @@ max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ + "method": "eagle", "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py new file mode 100644 index 000000000000..d96434817aba --- /dev/null +++ b/vllm/model_executor/models/llama_eagle.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM) + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__(config, layer_id, prefix) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if layer_id == 0: + del self.input_layernorm + self.input_layernorm = lambda x: x + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix("embed_tokens", prefix), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + config, + i, + prefix=maybe_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states + residual + + +class LlamaForCausalLMEagle(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config) + config = vllm_config.model_config.hf_config + self.config = config + + self.model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + # Llama 3.2 1B Instruct set tie_word_embeddings to True + # Llama 3.1 8B Instruct set tie_word_embeddings to False + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + loader.load_weight(name, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6ead6509bfe8..899309a69766 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -200,6 +200,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "LlamaForCausalLMEagle": ("llama_eagle", "LlamaForCausalLMEagle"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index dd806061ff58..c33493926e82 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -5,6 +5,7 @@ from transformers import AutoConfig, PretrainedConfig +import vllm.envs as envs from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config @@ -41,8 +42,10 @@ def __init__(self, self.truncated_vocab_size = self.model.vocab_size if \ truncated_vocab_size is None else truncated_vocab_size - if "architectures" not in kwargs: + if not envs.VLLM_USE_V1: kwargs["architectures"] = ["EAGLEModel"] + else: + kwargs["architectures"] = ["LlamaForCausalLMEagle"] super().__init__(**kwargs) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 57c6b652593d..172770be7681 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -6,6 +6,7 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context +from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -176,26 +177,13 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - self.model = DummyEagleModel() - self.model.get_input_embeddings = target_model.get_input_embeddings - self.model.compute_logits = target_model.compute_logits - - -# FIXME(woosuk): This is a dummy model for testing. -# Remove this once we have a real model. -class DummyEagleModel(nn.Module): - - def __init__(self): - super().__init__() - - def forward( - self, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - positions: torch.Tensor, - ) -> torch.Tensor: - input_embeddings = self.get_input_embeddings(input_ids) - return hidden_states + input_embeddings # Dummy return. + import copy + eagle_vllm_config = copy.deepcopy(self.vllm_config) + eagle_vllm_config.model_config = \ + self.vllm_config.speculative_config.draft_model_config + print(eagle_vllm_config) + self.model = get_model(vllm_config=eagle_vllm_config) + print(f"Loaded draft model: {self.model}") # FIXME(woosuk): The logic here is duplicated with the main sampling code. From b566516f012ffaa4b4b1ff96cd72764b2deddb01 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 2 Apr 2025 16:57:47 -0700 Subject: [PATCH 02/17] fix config and init Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 24 +++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index d96434817aba..4f906c031a73 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -27,7 +27,7 @@ def __init__( layer_id: int = 0, prefix: str = "", ) -> None: - super().__init__(config, layer_id, prefix) + super().__init__(config, prefix=prefix) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 @@ -83,14 +83,12 @@ def forward( class LlamaForCausalLMEagle(LlamaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config) + nn.Module.__init__(self) config = vllm_config.model_config.hf_config self.config = config + self.model = LlamaModel(config=config, + prefix=maybe_prefix(prefix, "eagle")) - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size @@ -112,9 +110,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + print("===========") + print_module_names(self.model) + print("===========") loader = AutoWeightsLoader(self) for name, loaded_weight in weights: if "lm_head" not in name: - name = "model." + name - loader.load_weight(name, loaded_weight) + name = "eagle." + name + loader.load_weights(name, loaded_weight) + + +def print_module_names(module, prefix=""): + for name, child in module.named_children(): + full_name = f"{prefix}.{name}" if prefix else name + print(full_name) + print_module_names(child, full_name) From bc1b7d07f16ed5115f2ca2fb2b80a180d538b46f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 2 Apr 2025 18:00:03 -0700 Subject: [PATCH 03/17] torch compile Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 39 +++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 4f906c031a73..b64c5a4755bf 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -6,6 +6,7 @@ import torch.nn as nn from transformers import LlamaConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -36,26 +37,29 @@ def __init__( self.input_layernorm = lambda x: x +@support_torch_compile class LlamaModel(nn.Module): def __init__( self, - config: LlamaConfig, + *, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config self.config = config self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - prefix=maybe_prefix("embed_tokens", prefix), + prefix=maybe_prefix(prefix, "embed_tokens"), ) self.layers = nn.ModuleList([ LlamaDecoderLayer( config, i, - prefix=maybe_prefix(f"layers.{i}", prefix), + prefix=maybe_prefix(prefix, f"layers.{i}"), ) for i in range(config.num_hidden_layers) ]) self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) @@ -86,8 +90,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config self.config = config - self.model = LlamaModel(config=config, - prefix=maybe_prefix(prefix, "eagle")) + self.model = LlamaModel(vllm_config=vllm_config, prefix="") self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size @@ -109,20 +112,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.truncated_vocab_size, logit_scale) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.model(input_ids, positions, hidden_states) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - print("===========") - print_module_names(self.model) - print("===========") loader = AutoWeightsLoader(self) + model_weights = {} for name, loaded_weight in weights: if "lm_head" not in name: - name = "eagle." + name - loader.load_weights(name, loaded_weight) - + name = "model." + name + print(name) + model_weights[name] = loaded_weight -def print_module_names(module, prefix=""): - for name, child in module.named_children(): - full_name = f"{prefix}.{name}" if prefix else name - print(full_name) - print_module_names(child, full_name) + loader.load_weights( + self.maybe_remap_mistral(name, loaded_weight) + for name, loaded_weight in weights) From e5ad74827c29f4f6d9bde2855cebad7a8ccc13a9 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 2 Apr 2025 19:00:16 -0700 Subject: [PATCH 04/17] remove lambda, add prefix Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index b64c5a4755bf..e49c339b7091 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -6,7 +6,6 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -34,10 +33,9 @@ def __init__( # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 if layer_id == 0: del self.input_layernorm - self.input_layernorm = lambda x: x + self.input_layernorm = nn.Identity() -@support_torch_compile class LlamaModel(nn.Module): def __init__( @@ -90,7 +88,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config self.config = config - self.model = LlamaModel(vllm_config=vllm_config, prefix="") + self.model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size From 41e4d03c8dfdeeb4d01da68d788e7fcaca173f7f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 3 Apr 2025 10:40:25 -0700 Subject: [PATCH 05/17] runnable v1 Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 13 +++++++------ vllm/v1/spec_decode/eagle.py | 22 +++++++++++++--------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index e49c339b7091..345aaa32c51e 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -57,7 +57,7 @@ def __init__( LlamaDecoderLayer( config, i, - prefix=maybe_prefix(prefix, f"layers.{i}"), + prefix=maybe_prefix(prefix, f"layers.{i + 33}"), ) for i in range(config.num_hidden_layers) ]) self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) @@ -88,8 +88,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config self.config = config - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.eagle_model = LlamaModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "eagle_model")) self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size @@ -97,7 +98,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + self.lm_head = self.eagle_model.embed_tokens else: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -117,7 +118,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - return self.model(input_ids, positions, hidden_states) + return self.eagle_model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -125,7 +126,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): model_weights = {} for name, loaded_weight in weights: if "lm_head" not in name: - name = "model." + name + name = "eagle_model." + name print(name) model_weights[name] = loaded_weight diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 172770be7681..597cf67396cd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -23,7 +23,8 @@ def __init__( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, - device=device) + device=device, + dtype=torch.int32) def propose( self, @@ -55,7 +56,9 @@ def propose( # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] input_ids[last_token_indices] = next_token_ids - seq_lens = target_positions[last_token_indices] + 1 + # FA requires seq_len to have dtype int32. + seq_lens = (target_positions[last_token_indices] + 1).int() + # FIXME(woosuk): The below two ops cause synchronization. Optimize. max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() @@ -99,7 +102,7 @@ def propose( hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size] + attn_metadata.query_start_loc = self.arange[:batch_size + 1] for _ in range(self.num_speculative_tokens - 1): # Update the inputs. input_ids = draft_token_ids_list[-1] @@ -177,13 +180,14 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - import copy - eagle_vllm_config = copy.deepcopy(self.vllm_config) - eagle_vllm_config.model_config = \ + old_config = self.vllm_config.model_config + + self.vllm_config.model_config = \ self.vllm_config.speculative_config.draft_model_config - print(eagle_vllm_config) - self.model = get_model(vllm_config=eagle_vllm_config) - print(f"Loaded draft model: {self.model}") + self.model = get_model(vllm_config=self.vllm_config) + + # Resume config + self.vllm_config.model_config = old_config # FIXME(woosuk): The logic here is duplicated with the main sampling code. From 10b107d2ae2a5a5c750684d9a27f0b40841e6420 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 3 Apr 2025 14:19:33 -0700 Subject: [PATCH 06/17] change the way of getting vllm_config, cleanup Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/model_loader/loader.py | 4 +- vllm/model_executor/models/llama_eagle.py | 58 ++++++++++------------ vllm/v1/spec_decode/eagle.py | 31 +++++++++--- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 5649cf2dd2cf..850d4d130046 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -403,7 +403,7 @@ def _hpu_weights_iterator(iterator: Generator): return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) - def _get_all_weights( + def get_all_weights( self, model_config: ModelConfig, model: nn.Module, @@ -442,7 +442,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( - self._get_all_weights(model_config, model)) + self.get_all_weights(model_config, model)) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 345aaa32c51e..91198960bb52 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -6,11 +6,11 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.config import VllmConfig +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -41,26 +41,27 @@ class LlamaModel(nn.Module): def __init__( self, *, - vllm_config: VllmConfig, + model_config: ModelConfig, + start_layer_id: int = 0, prefix: str = "", ) -> None: super().__init__() - config = vllm_config.model_config.hf_config - self.config = config - self.vocab_size = config.vocab_size + self.config = model_config.hf_config + self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, prefix=maybe_prefix(prefix, "embed_tokens"), ) self.layers = nn.ModuleList([ LlamaDecoderLayer( - config, + self.config, i, - prefix=maybe_prefix(prefix, f"layers.{i + 33}"), - ) for i in range(config.num_hidden_layers) + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) ]) - self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size) def forward( self, @@ -84,33 +85,26 @@ def forward( class LlamaForCausalLMEagle(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - self.config = config - self.eagle_model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "eagle_model")) - - self.truncated_vocab_size = config.truncated_vocab_size - self.unpadded_vocab_size = self.truncated_vocab_size + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: - self.lm_head = self.eagle_model.embed_tokens + self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, + self.config.vocab_size, + self.config.hidden_size, ) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.truncated_vocab_size, - logit_scale) + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) def forward( self, @@ -118,7 +112,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - return self.eagle_model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -126,7 +120,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): model_weights = {} for name, loaded_weight in weights: if "lm_head" not in name: - name = "eagle_model." + name + name = "model." + name print(name) model_weights[name] = loaded_weight diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 597cf67396cd..c6ff75accf34 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,9 +4,11 @@ import triton import triton.language as tl -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context -from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models.llama_eagle import LlamaForCausalLMEagle from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -180,14 +182,27 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - old_config = self.vllm_config.model_config + loader = get_model_loader(self.vllm_config.load_config) + target_layer_num = self.vllm_config.model_config.get_num_layers( + self.vllm_config.parallel_config) - self.vllm_config.model_config = \ + draft_model_config = \ self.vllm_config.speculative_config.draft_model_config - self.model = get_model(vllm_config=self.vllm_config) - - # Resume config - self.vllm_config.model_config = old_config + # FIXME(lily): This does not handle with distributed inference. + target_device = self.vllm_config.device_config.device + # We need to set the vllm_config here to register attention + # layers in the forward context. + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + self.model = LlamaForCausalLMEagle( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) # FIXME(woosuk): The logic here is duplicated with the main sampling code. From 59ee450306d3d719f78ad60c77ba9b739bc5cb11 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 3 Apr 2025 14:28:10 -0700 Subject: [PATCH 07/17] minor Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 91198960bb52..ae62994a669e 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -121,8 +121,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "lm_head" not in name: name = "model." + name - print(name) - model_weights[name] = loaded_weight + model_weights[name] = loaded_weight loader.load_weights( self.maybe_remap_mistral(name, loaded_weight) From 2ce60846d3bcd2e2fe6ca291e5a38a811c4f7941 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 5 Apr 2025 22:18:15 -0700 Subject: [PATCH 08/17] fix weights loading Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 57 ++++++++++++++++------- vllm/v1/spec_decode/eagle.py | 1 + 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index ae62994a669e..61953cbbd156 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Tuple +from typing import Iterable, Set, Tuple import torch import torch.nn as nn @@ -10,7 +10,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) @@ -61,7 +62,8 @@ def __init__( ) for i in range(self.config.num_hidden_layers) ]) self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size) + self.config.hidden_size, + bias=False) def forward( self, @@ -82,6 +84,35 @@ def forward( ) return hidden_states + residual + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class LlamaForCausalLMEagle(LlamaForCausalLM): @@ -92,16 +123,6 @@ def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): start_layer_id=start_layer_id, prefix="model") - # Llama 3.2 1B Instruct set tie_word_embeddings to True - # Llama 3.1 8B Instruct set tie_word_embeddings to False - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, - ) - logit_scale = getattr(self.config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=logit_scale) @@ -115,7 +136,11 @@ def forward( return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) model_weights = {} for name, loaded_weight in weights: @@ -123,6 +148,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = "model." + name model_weights[name] = loaded_weight - loader.load_weights( - self.maybe_remap_mistral(name, loaded_weight) - for name, loaded_weight in weights) + loader.load_weights(model_weights.items()) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index c6ff75accf34..503f2fa72848 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -203,6 +203,7 @@ def load_model(self, target_model: nn.Module) -> None: loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) + self.model.lm_head = target_model.lm_head # FIXME(woosuk): The logic here is duplicated with the main sampling code. From b0388cff57fe259e1eb681dd3945fda896a90f4b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sat, 5 Apr 2025 23:26:14 -0700 Subject: [PATCH 09/17] fix layernorm and fix cudagraph Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 61953cbbd156..4bc419b51176 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -25,14 +25,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def __init__( self, config: LlamaConfig, - layer_id: int = 0, + disable_input_layernorm: bool, prefix: str = "", ) -> None: super().__init__(config, prefix=prefix) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 - if layer_id == 0: + if disable_input_layernorm: del self.input_layernorm self.input_layernorm = nn.Identity() @@ -57,7 +57,7 @@ def __init__( self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, - i, + i == 0, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), ) for i in range(self.config.num_hidden_layers) ]) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 513806332efe..2da743a0e32b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1171,11 +1171,15 @@ def execute_model( if spec_decode_metadata is None: # input_ids can be None for multimodal models. + # We need to slice token_ids, positions, and hidden_states + # because the eagle head does not use cuda graph and should + # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions - target_hidden_states = hidden_states + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens, :] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc + token_indices = None else: # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens From a6f46cfa2f8c2b615746222ca608999c6cc4515f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 7 Apr 2025 10:53:50 -0700 Subject: [PATCH 10/17] example Signed-off-by: LiuXiaoxuanPKU --- examples/offline_inference/eagle.py | 52 +++++++++++++++++++---------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index f7b4765f78dd..66a46f0e4c51 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -2,10 +2,11 @@ import argparse import json import os +import time from transformers import AutoTokenizer -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams, envs parser = argparse.ArgumentParser() @@ -31,7 +32,8 @@ print(args) model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" -eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" +# eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" +eagle_dir = "yuhuili/EAGLE-LLaMA3-Instruct-8B" max_model_len = 2048 @@ -69,28 +71,42 @@ max_model_len=max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, - speculative_config={ - "method": "eagle", - "model": eagle_dir, - "num_speculative_tokens": args.num_spec_tokens, - "draft_tensor_parallel_size": args.draft_tp, - "max_model_len": max_model_len, - }, + # speculative_config={ + # "method": "eagle", + # "model": eagle_dir, + # "num_speculative_tokens": args.num_spec_tokens, + # "draft_tensor_parallel_size": args.draft_tp, + # "max_model_len": max_model_len, + # }, disable_log_stats=False, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) +# warmup outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) -# calculate the average number of accepted tokens per forward pass, +1 is -# to account for the token from the target model that's always going to be -# accepted -acceptance_counts = [0] * (args.num_spec_tokens + 1) -for output in outputs: - for step, count in enumerate(output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count +start = time.time() +repeat = 5 +for _ in range(repeat): + outputs = llm.generate(prompt_token_ids=prompt_ids, + sampling_params=sampling_params) +end = time.time() -print(f"mean acceptance length: \ - {sum(acceptance_counts) / acceptance_counts[0]:.2f}") +for output in outputs: + print(output.outputs[0].text) + +print(f"total time: {(end - start) / repeat:.2f}s") +if not envs.VLLM_USE_V1: + # calculate the average number of accepted tokens per forward pass, +1 is + # to account for the token from the target model that's always going to be + # accepted + acceptance_counts = [0] * (args.num_spec_tokens + 1) + for output in outputs: + for step, count in enumerate( + output.metrics.spec_token_acceptance_counts): + acceptance_counts[step] += count + + print(f"mean acceptance length: \ + {sum(acceptance_counts) / acceptance_counts[0]:.2f}") From 5bb90c424c756b6fd1529ec2491684f0b6806c66 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 7 Apr 2025 20:22:24 -0700 Subject: [PATCH 11/17] revert example and test Signed-off-by: LiuXiaoxuanPKU --- examples/offline_inference/eagle.py | 52 ++++++++++------------------- vllm/v1/spec_decode/eagle.py | 5 ++- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 66a46f0e4c51..b6771e36fc01 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -2,11 +2,10 @@ import argparse import json import os -import time from transformers import AutoTokenizer -from vllm import LLM, SamplingParams, envs +from vllm import LLM, SamplingParams parser = argparse.ArgumentParser() @@ -32,8 +31,7 @@ print(args) model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" -# eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" -eagle_dir = "yuhuili/EAGLE-LLaMA3-Instruct-8B" +eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" max_model_len = 2048 @@ -71,42 +69,28 @@ max_model_len=max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, - # speculative_config={ - # "method": "eagle", - # "model": eagle_dir, - # "num_speculative_tokens": args.num_spec_tokens, - # "draft_tensor_parallel_size": args.draft_tp, - # "max_model_len": max_model_len, - # }, + speculative_config={ + "method": "eagle", + "model": eagle_dir, + "num_speculative_tokens": args.num_spec_tokens, + "draft_tensor_parallel_size": args.draft_tp, + "max_model_len": max_model_len, + }, disable_log_stats=False, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) -# warmup outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) -start = time.time() -repeat = 5 -for _ in range(repeat): - outputs = llm.generate(prompt_token_ids=prompt_ids, - sampling_params=sampling_params) -end = time.time() - +# calculate the average number of accepted tokens per forward pass, +1 is +# to account for the token from the target model that's always going to be +# accepted +acceptance_counts = [0] * (args.num_spec_tokens + 1) for output in outputs: - print(output.outputs[0].text) - -print(f"total time: {(end - start) / repeat:.2f}s") -if not envs.VLLM_USE_V1: - # calculate the average number of accepted tokens per forward pass, +1 is - # to account for the token from the target model that's always going to be - # accepted - acceptance_counts = [0] * (args.num_spec_tokens + 1) - for output in outputs: - for step, count in enumerate( - output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count - - print(f"mean acceptance length: \ - {sum(acceptance_counts) / acceptance_counts[0]:.2f}") + for step, count in enumerate(output.metrics.spec_token_acceptance_counts): + acceptance_counts[step] += count + +print(f"mean acceptance length: \ + {sum(acceptance_counts) / acceptance_counts[0]:.2f}") \ No newline at end of file diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 503f2fa72848..d1940777b00a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -24,7 +24,10 @@ def __init__( self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, device=device, dtype=torch.int32) From 560eaee61d6eac3cf60868e3bcd2807442d7e2b5 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 8 Apr 2025 14:50:15 -0700 Subject: [PATCH 12/17] fix comments Signed-off-by: LiuXiaoxuanPKU --- vllm/model_executor/models/llama_eagle.py | 2 +- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/configs/eagle.py | 2 +- vllm/v1/spec_decode/eagle.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 3 +-- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 4bc419b51176..28ad6128c4f1 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -114,7 +114,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class LlamaForCausalLMEagle(LlamaForCausalLM): +class EagleLlamaForCausalLM(LlamaForCausalLM): def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): nn.Module.__init__(self) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 340281e7d709..5ffccb562782 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -204,7 +204,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), - "LlamaForCausalLMEagle": ("llama_eagle", "LlamaForCausalLMEagle"), + "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index c33493926e82..3a9ad3e0ffc8 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -45,7 +45,7 @@ def __init__(self, if not envs.VLLM_USE_V1: kwargs["architectures"] = ["EAGLEModel"] else: - kwargs["architectures"] = ["LlamaForCausalLMEagle"] + kwargs["architectures"] = ["EagleLlamaForCausalLM"] super().__init__(**kwargs) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4d2ba4b18725..2322463c0713 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.models.llama_eagle import LlamaForCausalLMEagle +from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -198,7 +198,7 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = LlamaForCausalLMEagle( + self.model = EagleLlamaForCausalLM( model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0c45866afc5a..0d99a86d93af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1196,10 +1196,9 @@ def execute_model( # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens, :] + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc - token_indices = None else: # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens From 32bbd9847320a09042b9a5bb26a4de7573bc3bd7 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 8 Apr 2025 22:10:27 -0700 Subject: [PATCH 13/17] add tests Signed-off-by: LiuXiaoxuanPKU --- ...ram_spec_decode.py => test_spec_decode.py} | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) rename tests/v1/e2e/{test_ngram_spec_decode.py => test_spec_decode.py} (65%) diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_spec_decode.py similarity index 65% rename from tests/v1/e2e/test_ngram_spec_decode.py rename to tests/v1/e2e/test_spec_decode.py index 7c7c2f02c078..fcca4bdc220d 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -53,6 +53,11 @@ def model_name(): return "meta-llama/Meta-Llama-3-8B-Instruct" +@pytest.fixture +def eagle_model_name(): + return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + + def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], @@ -95,3 +100,47 @@ def test_ngram_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.7 * len(ref_outputs)) del spec_llm + + +def test_eagle_correctness( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, + eagle_model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using eagle speculative decoding. + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name, max_model_len=1024) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "eagle", + "model": eagle_model_name, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 90% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.9 * len(ref_outputs)) + del spec_llm From 302b591ac37cf5c2fb19e896ae8099918aba2ecb Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 9 Apr 2025 09:07:15 -0700 Subject: [PATCH 14/17] less strict Signed-off-by: LiuXiaoxuanPKU --- tests/v1/e2e/test_spec_decode.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index fcca4bdc220d..673714980592 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -140,7 +140,7 @@ def test_eagle_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 90% of the prompts to match exactly + # Heuristic: expect at least 70% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.9 * len(ref_outputs)) + assert matches > int(0.7 * len(ref_outputs)) del spec_llm From bcf2388929d8516a078df0b38626ff35b6b0c96b Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 9 Apr 2025 10:08:46 -0700 Subject: [PATCH 15/17] fix registry test Signed-off-by: LiuXiaoxuanPKU --- tests/models/registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 10b93460c56b..75a2b1e52db7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -367,6 +367,9 @@ def check_available_online( "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), + "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B-vllm"), # noqa: E501 } _TRANSFORMERS_MODELS = { From 29e76372fb5d8c3f8848980837c33e59a5495b23 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 9 Apr 2025 13:05:50 -0700 Subject: [PATCH 16/17] minor Signed-off-by: LiuXiaoxuanPKU --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 75a2b1e52db7..c137578ddf05 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -369,7 +369,7 @@ def check_available_online( trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B-vllm"), # noqa: E501 + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B"), # noqa: E501 } _TRANSFORMERS_MODELS = { From d511a0e2533422a217c4f0b916238555cec83cd8 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 9 Apr 2025 17:10:03 -0700 Subject: [PATCH 17/17] force tokenizer Signed-off-by: LiuXiaoxuanPKU --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index c137578ddf05..67784a5f8fb7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -369,7 +369,8 @@ def check_available_online( trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B"), # noqa: E501 + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 } _TRANSFORMERS_MODELS = {