 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
     truncated_vocab_size < vocab_size. To use this technique, one has to find
     the top-k most frequent tokens in target dataset and add that as a tensor
     in the draft checkpoint (using key token_map). Also, the draft config
-    needs to have truncated_vocab_size (=k) as an attribute."""
+    needs to have truncated_vocab_size (=k) as an attribute.
+    4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP
+    module with regard to the use of additional RMS norms. The original
+    EAGLE architecture 1) skips the pre-attention norm in its first
+    transformer block, and 2) skips the final output norm, both of which we
+    found to be suboptimal. We also add support for separate norms applied
+    to the token embedding and to the hidden states before projection, as
+    in DeepSeek MTP, which we found to improve performance as well.
+    """
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
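
For context on the norms the docstring refers to: a minimal RMSNorm sketch, assuming the standard `y = x * rsqrt(mean(x^2, dim=-1) + eps) * weight` formulation, which is the quantity vLLM's `RMSNorm` layer computes (the layer itself uses fused kernels). Shapes and values here are illustrative only.

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the hidden dimension,
    # then apply the learned elementwise scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(4, 8)          # [num_tokens, hidden_size]
w = torch.ones(8)              # learned scale, initialized to 1
print(rms_norm(x, w).shape)    # torch.Size([4, 8])
```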
@@ -81,9 +90,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # While weights and biases are generally not needed,
         # they are retained here to support certain unit tests
         # (e.g., spec_decode/e2e/test_eagle_correctness.py).
-        self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
-            weight=self.model.model.layers[0].input_layernorm.weight)
-        self.model.model.norm = DummyOutputNorm()
+        if not hasattr(self.config.model,
+                       "skip_prenorm") or self.config.model.skip_prenorm:
+            self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
+                weight=self.model.model.layers[0].input_layernorm.weight)
+
+        if not hasattr(
+                self.config.model,
+                "skip_output_norm") or self.config.model.skip_output_norm:
+            self.model.model.norm = DummyOutputNorm()
+
+        self.add_para_norm = False
+        if hasattr(self.config.model,
+                   "add_para_norm") and self.config.model.add_para_norm:
+            self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.add_para_norm = True
 
         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
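
A hedged sketch of how the three flags above resolve when they are absent from the draft model config (`SimpleNamespace` stands in for `self.config.model`; this is not a vLLM API): the two skip flags default to True, preserving the original EAGLE behavior for existing checkpoints, while `add_para_norm` defaults to False.

```python
from types import SimpleNamespace

def resolve_flags(model_cfg):
    # Mirrors the hasattr checks in __init__ above.
    skip_prenorm = (not hasattr(model_cfg, "skip_prenorm")
                    or model_cfg.skip_prenorm)
    skip_output_norm = (not hasattr(model_cfg, "skip_output_norm")
                        or model_cfg.skip_output_norm)
    add_para_norm = getattr(model_cfg, "add_para_norm", False)
    return skip_prenorm, skip_output_norm, add_para_norm

legacy = SimpleNamespace()                      # original EAGLE draft config
print(resolve_flags(legacy))                    # (True, True, False)

mtp = SimpleNamespace(skip_prenorm=False,       # DeepSeek-MTP-style draft
                      skip_output_norm=False,
                      add_para_norm=True)
print(resolve_flags(mtp))                       # (False, False, True)
```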
@@ -128,8 +150,17 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings(input_ids)
 
-        inputs_embeds = self.fc(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+        if self.add_para_norm:
+            inputs_embeds = torch.cat([
+                self.enorm(inputs_embeds),
+                self.hnorm(previous_hidden_states)
+            ],
+                                      dim=-1)
+        else:
+            inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
+                                      dim=-1)
+
+        inputs_embeds = self.fc(inputs_embeds)
 
         inputs_embeds[positions == 0] = 0  # masking inputs at position=0
 
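
An illustrative end-to-end sketch of the `add_para_norm` branch above, with hypothetical shapes; `enorm`, `hnorm`, and `fc` mirror the modules in this file but are rebuilt from plain torch here rather than taken from the vLLM model.

```python
import torch
import torch.nn as nn

T, H = 5, 8                                     # tokens, hidden size

def rms_norm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

enorm_w, hnorm_w = torch.ones(H), torch.ones(H)
fc = nn.Linear(2 * H, H, bias=False)            # projection back to H

inputs_embeds = torch.randn(T, H)               # draft token embeddings
previous_hidden_states = torch.randn(T, H)      # from the target model
positions = torch.tensor([0, 1, 2, 0, 1])       # two packed sequences

# Normalize each stream separately, concatenate, project, then zero the
# fused embedding at sequence starts (position == 0 has no prior state).
fused = torch.cat([rms_norm(inputs_embeds, enorm_w),
                   rms_norm(previous_hidden_states, hnorm_w)], dim=-1)
inputs_embeds = fc(fused)                       # [T, 2H] -> [T, H]
inputs_embeds[positions == 0] = 0
print(inputs_embeds.shape)                      # torch.Size([5, 8])
```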
@@ -190,6 +221,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 else:
                     logger.warning_once("Found bias in the loaded weights but "
                                         "the model config doesn't have bias.")
+            elif name.startswith("enorm.weight"):
+                weight_loader = getattr(self.enorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.enorm.weight, loaded_weight)
+            elif name.startswith("hnorm.weight"):
+                weight_loader = getattr(self.hnorm.weight, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(self.hnorm.weight, loaded_weight)
             elif name.startswith("model.lm_head.") or name.startswith(
                     "model.model."):
                 model_weights[name.split("model.", 1)[-1]] = loaded_weight
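
The loader pattern in the new `enorm`/`hnorm` branches is the usual vLLM fallback: use a parameter's own `weight_loader` if one was attached (e.g. by tensor-parallel layers), otherwise copy the checkpoint tensor directly. Below is a simplified stand-in for `default_weight_loader`, not the real helper from vLLM's weight_utils.

```python
import torch

def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    # Simplified stand-in: shape-check, then copy the checkpoint tensor in.
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)

param = torch.nn.Parameter(torch.zeros(8))      # e.g. self.enorm.weight
loaded_weight = torch.ones(8)                   # tensor from the checkpoint

# Plain Parameters carry no custom loader, so getattr falls back.
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
print(param.data.sum().item())                  # 8.0
```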