From ef8d2325fb5d991d9b7a6968881dd8a199e582fd Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 5 Aug 2025 20:08:13 +0000 Subject: [PATCH 1/2] Polish Eagle3 config and model definitions --- .../convert/eagle/eagle3_converter.py | 2 - src/speculators/models/eagle.py | 111 ++++++----- src/speculators/models/eagle3.py | 184 ++++-------------- 3 files changed, 106 insertions(+), 191 deletions(-) diff --git a/src/speculators/convert/eagle/eagle3_converter.py b/src/speculators/convert/eagle/eagle3_converter.py index 7f643483..167a4a1d 100644 --- a/src/speculators/convert/eagle/eagle3_converter.py +++ b/src/speculators/convert/eagle/eagle3_converter.py @@ -178,9 +178,7 @@ def _build_eagle3_speculator_config( return Eagle3SpeculatorConfig( transformer_layer_config=transformer_config, speculators_config=speculators_config, - draft_vocab_size=eagle_config.get("draft_vocab_size", 32000), norm_before_residual=norm_before_residual, - target_hidden_size=eagle_config.get("target_hidden_size"), ) def _create_transformer_config_from_eagle( diff --git a/src/speculators/models/eagle.py b/src/speculators/models/eagle.py index b7cfc086..c987f9d1 100644 --- a/src/speculators/models/eagle.py +++ b/src/speculators/models/eagle.py @@ -18,7 +18,13 @@ from typing import Any, ClassVar, Literal, Optional, Union import torch -from pydantic import Field, field_serializer, field_validator, model_validator +from pydantic import ( + BaseModel, + Field, + field_serializer, + field_validator, + model_validator, +) from torch import nn from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -32,11 +38,66 @@ __all__ = [ "EagleSpeculator", "EagleSpeculatorConfig", + "TransformerLayerConfigMixin", ] +class TransformerLayerConfigMixin(BaseModel): + transformer_layer_config: PretrainedConfig = Field( + default_factory=LlamaConfig, + description=( + "Configuration object for the transformer layer architecture. " + "Must be a PretrainedConfig instance that matches the requirements " + "of the transformer_layer_architecture. Contains parameters such as " + "hidden_size, num_attention_heads, intermediate_size, vocab_size, " + "and other architecture-specific settings. " + "Additionally, it contains all the necessary information to check and " + "validate compatibility between the speculator and verifier models, " + "such as the vocab_size used for the speculator and the hidden_size " + "used for the speculator's transformer layer, which must match " + "the verifier's hidden_size according to the algorithm's design." + ), + ) + + @field_serializer("transformer_layer_config") + def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict: + """ + Serialize the transformer_layer_config to a dictionary for JSON storage. + + Converts the PretrainedConfig object to its dictionary representation + using to_diff_dict() to only include non-default values. + + :param value: The PretrainedConfig instance to serialize + :return: Dictionary representation of the transformer layer configuration + """ + return value.to_diff_dict() + + @field_validator("transformer_layer_config", mode="before") + @classmethod + def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig: + """ + Validate and convert transformer_layer_config to a PretrainedConfig instance. + + Accepts either a dictionary that can be converted to a PretrainedConfig + or an existing PretrainedConfig instance. 
+ + :param value: The value to validate (dict or PretrainedConfig) + :return: A validated PretrainedConfig instance + :raises ValueError: If the value cannot be converted to a PretrainedConfig + """ + if isinstance(value, dict): + return AutoConfig.for_model(**value) + if isinstance(value, PretrainedConfig): + return value + + raise ValueError( + "transformer_layer_config must be a PretrainedConfig instance or a " + "dictionary that can be converted to a PretrainedConfig." + ) + + @SpeculatorModelConfig.register("eagle") -class EagleSpeculatorConfig(SpeculatorModelConfig): +class EagleSpeculatorConfig(SpeculatorModelConfig, TransformerLayerConfigMixin): """ A SpeculatorModelConfig implementation to be used with the EagleSpeculator for EAGLE and HASS variants for spec decoding: @@ -91,16 +152,6 @@ class EagleSpeculatorConfig(SpeculatorModelConfig): "transformer decoder layer class (e.g., 'LlamaDecoderLayer')." ), ) - transformer_layer_config: PretrainedConfig = Field( - default_factory=LlamaConfig, - description=( - "Configuration object for the transformer layer architecture. " - "Must be a PretrainedConfig instance that matches the requirements " - "of the transformer_layer_architecture. Contains parameters such as " - "hidden_size, num_attention_heads, intermediate_size, vocab_size, " - "and other architecture-specific settings." - ), - ) layernorms: bool = Field( default=False, description=( @@ -140,42 +191,6 @@ def check_add_architectures(self) -> Self: return self - @field_serializer("transformer_layer_config") - def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict: - """ - Serialize the transformer_layer_config to a dictionary for JSON storage. - - Converts the PretrainedConfig object to its dictionary representation - using to_diff_dict() to only include non-default values. - - :param value: The PretrainedConfig instance to serialize - :return: Dictionary representation of the transformer layer configuration - """ - return value.to_diff_dict() - - @field_validator("transformer_layer_config", mode="before") - @classmethod - def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig: - """ - Validate and convert transformer_layer_config to a PretrainedConfig instance. - - Accepts either a dictionary that can be converted to a PretrainedConfig - or an existing PretrainedConfig instance. - - :param value: The value to validate (dict or PretrainedConfig) - :return: A validated PretrainedConfig instance - :raises ValueError: If the value cannot be converted to a PretrainedConfig - """ - if isinstance(value, dict): - return AutoConfig.for_model(**value) - if isinstance(value, PretrainedConfig): - return value - - raise ValueError( - "transformer_layer_config must be a PretrainedConfig instance or a " - "dictionary that can be converted to a PretrainedConfig." 
- ) - @SpeculatorModel.register("eagle") class EagleSpeculator(SpeculatorModel): diff --git a/src/speculators/models/eagle3.py b/src/speculators/models/eagle3.py index b357cbc2..defc2128 100644 --- a/src/speculators/models/eagle3.py +++ b/src/speculators/models/eagle3.py @@ -13,15 +13,14 @@ """ import os -from typing import Any, ClassVar, Literal, Optional, Union +from typing import ClassVar, Literal, Optional, Union import torch -from pydantic import Field, field_serializer, field_validator +from pydantic import Field from torch import nn from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaMLP, LlamaRMSNorm, @@ -30,6 +29,7 @@ ) from speculators import SpeculatorModel, SpeculatorModelConfig +from speculators.models.eagle import TransformerLayerConfigMixin __all__ = [ "Eagle3Attention", @@ -40,7 +40,7 @@ @SpeculatorModelConfig.register("eagle3") -class Eagle3SpeculatorConfig(SpeculatorModelConfig): +class Eagle3SpeculatorConfig(SpeculatorModelConfig, TransformerLayerConfigMixin): """ Configuration for EAGLE-3 speculator with vocabulary mapping. @@ -57,52 +57,11 @@ class Eagle3SpeculatorConfig(SpeculatorModelConfig): default_factory=lambda: ["Eagle3Speculator"], description="Model architectures that can load these weights", ) - - transformer_layer_config: PretrainedConfig = Field( - default_factory=LlamaConfig, - description="Configuration for the transformer decoder layer", - ) - - draft_vocab_size: int = Field( - default=32000, - description="Size of draft model vocabulary for speculation", - ) - norm_before_residual: bool = Field( default=False, description="Apply hidden_norm before storing residual", ) - target_hidden_size: Optional[int] = Field( - default=None, - description="Hidden size of the target model (if different from draft model)", - ) - - @property - def target_vocab_size(self) -> int: - """Get target vocabulary size from transformer config.""" - return self.transformer_layer_config.vocab_size - - @field_serializer("transformer_layer_config") - def serialize_transformer_config(self, value: PretrainedConfig) -> dict: - """Serialize transformer config to dict.""" - return value.to_diff_dict() - - @field_validator("transformer_layer_config", mode="before") - @classmethod - def validate_transformer_config(cls, value: Any) -> PretrainedConfig: - """Validate and convert transformer config.""" - if isinstance(value, dict): - config_class: type[PretrainedConfig] = LlamaConfig - if "model_type" in value: - from transformers import AutoConfig - - config_class = AutoConfig.for_model( - model_type=value["model_type"] - ).__class__ - return config_class(**value) - return value - class Eagle3Attention(nn.Module): """ @@ -320,6 +279,12 @@ class Eagle3Speculator(SpeculatorModel): EAGLE-3 processes concatenated hidden states from multiple verifier layers through a fusion layer, then combines with embeddings for a custom decoder layer that accepts 2x hidden_size input. + + Future work includes: + - Implementing verifier attachment logic to load embeddings and heads + from verifier models. + - Enabling lm_head pruning and remapping to target vocabulary. 
+ - Adding support for generic transformer layer configurations """ config_class: ClassVar[type[Eagle3SpeculatorConfig]] = Eagle3SpeculatorConfig # type: ignore[misc] @@ -345,22 +310,16 @@ def __init__( """ if not isinstance(config, Eagle3SpeculatorConfig): raise ValueError( - f"config must be Eagle3SpeculatorConfig, got {type(config)}" + "config must be an instance of Eagle3SpeculatorConfig, " + f"got {type(config)} instead." ) - self.config: Eagle3SpeculatorConfig = config - + # Initialize model parameters from config self.hidden_size = config.transformer_layer_config.hidden_size - self.draft_vocab_size = config.draft_vocab_size - self.target_vocab_size = config.target_vocab_size - - # Use target_hidden_size if specified, otherwise use draft model's hidden_size - self.target_hidden_size = ( - config.target_hidden_size - if config.target_hidden_size is not None - else self.hidden_size - ) + self.target_hidden_size = self.hidden_size // 2 + self.vocab_size = config.transformer_layer_config.vocab_size + # Delayed initialization to ensure everything needed for attach_verifier is set super().__init__( config=config, verifier=verifier, @@ -368,19 +327,19 @@ def __init__( ) self.embed_tokens = nn.Embedding( - self.target_vocab_size, + self.vocab_size, self.hidden_size, - padding_idx=config.transformer_layer_config.pad_token_id - if hasattr(config.transformer_layer_config, "pad_token_id") - else None, + padding_idx=( + config.transformer_layer_config.pad_token_id + if hasattr(config.transformer_layer_config, "pad_token_id") + else None + ), ) - self.fc = nn.Linear( - 3 * self.target_hidden_size, # Use target model's hidden size + 3 * self.hidden_size, # Use target model's hidden size self.hidden_size, bias=False, ) - self.layers = nn.ModuleList( [ Eagle3DecoderLayer( @@ -390,32 +349,35 @@ def __init__( ) ] ) - self.norm = LlamaRMSNorm( self.hidden_size, eps=config.transformer_layer_config.rms_norm_eps, ) - self.lm_head = nn.Linear( self.hidden_size, - self.draft_vocab_size, + self.vocab_size, bias=False, ) - self.register_buffer( # type: ignore[attr-defined] - "d2t", - torch.zeros(self.draft_vocab_size, dtype=torch.long), - ) - self.register_buffer( # type: ignore[attr-defined] - "t2d", - torch.zeros(self.target_vocab_size, dtype=torch.bool), - ) + self.post_init() # type: ignore[attr-defined] - # Type hints for buffers - self.d2t: torch.Tensor - self.t2d: torch.Tensor + def attach_verifier( + self, + verifier: Union[str, os.PathLike, PreTrainedModel], + mode: Optional[Literal["full", "train_only"]] = None, + ) -> PreTrainedModel: + """ + no-op currently, need to determine when to load embedding and head from + verifier based on first creation of the model and subsequent loading. 
- self.post_init() # type: ignore[attr-defined] + :param verifier: Verifier model to attach + :param mode: Attachment mode, currently not used + :return: The attached verifier model + """ + return super().attach_verifier( + verifier=verifier, + mode=mode, + ) def forward( self, @@ -447,14 +409,11 @@ def forward( return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) - inputs_embeds = self.embed_tokens(input_ids) - fused_hidden = self.fc(hidden_states) - layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1) - batch_size, seq_length = layer_input.shape[:2] + if attention_mask is not None and attention_mask.dim() == 2: # noqa: PLR2004 past_key_values_length = ( past_key_values[0][0].shape[2] if past_key_values else 0 @@ -484,12 +443,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, ) - hidden_states = layer_outputs[0] - - hidden_states = self.norm(hidden_states) - - logits = self.compute_logits(hidden_states, map_to_target_vocab=True) + logits = self.norm(hidden_states) if not return_dict: return logits @@ -500,56 +455,3 @@ def forward( hidden_states=None, attentions=None, ) - - def compute_logits( - self, - hidden_states: torch.FloatTensor, - map_to_target_vocab: bool = True, - ) -> torch.FloatTensor: - """ - Compute logits with optional vocabulary mapping. - - :param hidden_states: Hidden states from the model - :param map_to_target_vocab: Whether to map draft logits to target vocabulary - :return: Logits tensor - """ - logits = self.lm_head(hidden_states) - - if not map_to_target_vocab: - return logits - - batch_size, seq_length, _ = logits.shape - - draft_indices = torch.arange(self.draft_vocab_size, device=logits.device) - - target_indices = draft_indices + self.d2t - - mapped_logits = logits.new_full( - (batch_size, seq_length, self.target_vocab_size), float("-inf") - ) - - mapped_logits[:, :, target_indices] = logits - - return mapped_logits - - def map_draft_to_target_tokens( - self, draft_tokens: torch.LongTensor - ) -> torch.LongTensor: - """ - Map draft token IDs to target token IDs. - - :param draft_tokens: Draft vocabulary token IDs - :return: Target vocabulary token IDs - """ - return draft_tokens + self.d2t[draft_tokens] # type: ignore[return-value] - - def check_target_token_availability( - self, target_tokens: torch.LongTensor - ) -> torch.BoolTensor: - """ - Check if target tokens have draft equivalents. - - :param target_tokens: Target vocabulary token IDs - :return: Boolean mask indicating availability in draft vocabulary - """ - return self.t2d[target_tokens] # type: ignore[return-value] From 824b12201e9b7b75a95d5d858f3ecf7482438ddc Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 5 Aug 2025 16:12:27 -0400 Subject: [PATCH 2/2] Update src/speculators/models/eagle3.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/speculators/models/eagle3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/models/eagle3.py b/src/speculators/models/eagle3.py index defc2128..d99660b8 100644 --- a/src/speculators/models/eagle3.py +++ b/src/speculators/models/eagle3.py @@ -444,7 +444,7 @@ def forward( use_cache=use_cache, ) hidden_states = layer_outputs[0] - logits = self.norm(hidden_states) + logits = self.lm_head(self.norm(hidden_states)) if not return_dict: return logits
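
A minimal sketch of what the shared transformer_layer_config handling in
TransformerLayerConfigMixin does with dict versus PretrainedConfig inputs,
shown with AutoConfig directly; the llama values below are illustrative and
not taken from this patch:

    from transformers import AutoConfig, LlamaConfig

    # Dict path of the mixin's field_validator: resolved via AutoConfig.for_model,
    # so the dict must carry a "model_type" key.
    layer_dict = {"model_type": "llama", "hidden_size": 2048, "num_hidden_layers": 2}
    resolved = AutoConfig.for_model(**layer_dict)
    assert isinstance(resolved, LlamaConfig)

    # Serializer path: to_diff_dict() keeps only non-default values for JSON storage.
    assert resolved.to_diff_dict()["hidden_size"] == 2048

A PretrainedConfig instance passes through the validator unchanged, so either
form can be supplied for transformer_layer_config when building
EagleSpeculatorConfig or Eagle3SpeculatorConfig.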
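
A shape-only sketch of the EAGLE-3 fusion path exercised in
Eagle3Speculator.forward above, using plain torch modules; the sizes are
illustrative and this stands in for the real module rather than importing it.
The 3x input width of fc reflects the concatenated hidden states from three
verifier layers, and the decoder layer receives a 2x hidden_size input:

    import torch
    from torch import nn

    hidden_size, vocab_size, batch, seq = 64, 128, 2, 5

    embed_tokens = nn.Embedding(vocab_size, hidden_size)
    fc = nn.Linear(3 * hidden_size, hidden_size, bias=False)  # fuses 3 verifier layers

    input_ids = torch.randint(0, vocab_size, (batch, seq))
    verifier_hidden = torch.randn(batch, seq, 3 * hidden_size)  # concatenated layers

    inputs_embeds = embed_tokens(input_ids)            # [batch, seq, hidden]
    fused_hidden = fc(verifier_hidden)                 # [batch, seq, hidden]
    layer_input = torch.cat([inputs_embeds, fused_hidden], dim=-1)
    assert layer_input.shape == (batch, seq, 2 * hidden_size)  # 2x-wide decoder input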