 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch
 from torch import nn
-from transformers import OlmoeConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -117,20 +116,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 class OlmoeAttention(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -145,15 +145,15 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -166,7 +166,7 @@ def __init__(
         self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )
@@ -218,28 +218,15 @@ def forward(


 class OlmoeDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )

@@ -280,12 +267,16 @@ def forward(

 @support_torch_compile
 class OlmoeModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config

         self.vocab_size = config.vocab_size
         self.config = config
@@ -295,9 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix
-            ),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -339,7 +328,10 @@ def forward(
                 {"hidden_states": hidden_states, "residual": residual}
             )

-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+            layer_type=layer_type,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
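
Usage note: a minimal sketch (not part of this diff) of how the new layer_type hook could be reused by a model that shares the OLMoE architecture; the Custom* class names below are hypothetical and assume the refactored constructors above.

# Hypothetical subclass illustrating the layer_type parameter.
class CustomOlmoeDecoderLayer(OlmoeDecoderLayer):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Model-specific tweaks to the decoder layer would go here.


class CustomOlmoeForCausalLM(OlmoeForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Every layer built by OlmoeModel now comes from layer_type, so the
        # subclass only needs to pass its own layer class through.
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=CustomOlmoeDecoderLayer,
        )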