 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
-from transformers import OlmoeConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
@@ -103,20 +102,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class OlmoeAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -131,15 +132,15 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -153,7 +154,7 @@ def __init__(
                               eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )
@@ -204,29 +205,15 @@ def forward(
 
 class OlmoeDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          4096)
 
         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )
 
@@ -270,12 +257,14 @@ def forward(
 @support_torch_compile
 class OlmoeModel(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: type[nn.Module] = OlmoeDecoderLayer):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
 
         self.vocab_size = config.vocab_size
         self.config = config
@@ -285,8 +274,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 
@@ -328,7 +316,10 @@ def forward(
                 "residual": residual
             })
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -440,14 +431,19 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: type[nn.Module] = OlmoeDecoderLayer):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
+                                prefix=maybe_prefix(prefix, "model"),
+                                layer_type=layer_type)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config,
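
For reference, a minimal sketch of how the new layer_type hook could be used by a subclass. The class names and any variant-specific behaviour below are hypothetical illustrations, not part of this change.

# Illustrative sketch only: MyOlmoeDecoderLayer and MyOlmoeForCausalLM are
# hypothetical names used to show the layer_type override, not part of the diff.
from vllm.config import VllmConfig


class MyOlmoeDecoderLayer(OlmoeDecoderLayer):
    """Decoder-layer variant that reuses the OLMoE attention/MoE stack."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Variant-specific setup would go here.


class MyOlmoeForCausalLM(OlmoeForCausalLM):
    """Causal-LM wrapper that plugs the custom layer class into OlmoeModel."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        # layer_type is forwarded to OlmoeModel and used when building layers.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=MyOlmoeDecoderLayer)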