from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, is_pp_missing_parameter,
+    AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import Olmo3Config


class Olmo2Attention(nn.Module):
@@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
-        assert isinstance(self.config, Olmo2Config)
+        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        hidden_size = self.config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
@@ -111,22 +112,35 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.q_norm = RMSNorm(self.config.hidden_size,
                              eps=self.config.rms_norm_eps)

-        # Rotary embeddings.
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=self.max_position_embeddings,
-            base=self.rope_theta,  # type: ignore
-        )
        self.scaling = self.head_dim**-0.5
+
+        layer_idx = extract_layer_index(prefix)
+        sliding_window = None
+        if ((layer_types := getattr(self.config, "layer_types", None))
+                is not None and layer_types[layer_idx] == "sliding_attention"):
+            sliding_window = self.config.sliding_window
+
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
-            prefix=prefix,
+            per_layer_sliding_window=sliding_window,
+            prefix=f"{prefix}.attn",
+        )
+
+        # Rotary embeddings. Rope scaling is only applied on full attention
+        # layers.
+        self.rope_scaling = (self.config.rope_scaling
+                             if sliding_window is None else None)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,  # type: ignore
+            rope_scaling=self.rope_scaling,
        )

        # Attention output projection.
@@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

@@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
        # Attention block.
        self.self_attn = Olmo2Attention(vllm_config=vllm_config,
                                        prefix=f"{prefix}.self_attn")
@@ -261,7 +275,7 @@ class Olmo2Model(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
-        assert isinstance(self.config, Olmo2Config)
+        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
@@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
-        assert isinstance(config, Olmo2Config)
+        assert isinstance(config, (Olmo2Config, Olmo3Config))
        self.config = config
        self.model = Olmo2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
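
For reference, here is a minimal standalone sketch of the per-layer selection logic this diff introduces: a layer whose entry in layer_types is "sliding_attention" gets the config's sliding_window, and rope scaling is kept only on full-attention layers. The DummyConfig class, the simplified layer-index parsing, and the example rope-scaling values below are illustrative stand-ins, not vLLM APIs.

from dataclasses import dataclass
from typing import Optional

@dataclass
class DummyConfig:
    # Illustrative stand-in for the HF config fields referenced in the diff.
    layer_types: Optional[list]  # e.g. ["sliding_attention", "full_attention", ...]
    sliding_window: int = 4096
    rope_scaling: Optional[dict] = None

def layer_index_from_prefix(prefix: str) -> int:
    # Simplified analogue of extract_layer_index: pull the numeric component
    # out of a module prefix like "model.layers.7.self_attn".
    return next(int(part) for part in prefix.split(".") if part.isdigit())

def select_attention_params(config: DummyConfig, prefix: str):
    layer_idx = layer_index_from_prefix(prefix)
    sliding_window = None
    layer_types = getattr(config, "layer_types", None)
    if layer_types is not None and layer_types[layer_idx] == "sliding_attention":
        sliding_window = config.sliding_window
    # Rope scaling is only applied on full-attention layers.
    rope_scaling = config.rope_scaling if sliding_window is None else None
    return sliding_window, rope_scaling

# Hypothetical two-layer config for illustration only.
cfg = DummyConfig(layer_types=["sliding_attention", "full_attention"],
                  rope_scaling={"rope_type": "yarn", "factor": 4.0})
print(select_attention_params(cfg, "model.layers.0.self_attn"))  # (4096, None)
print(select_attention_params(cfg, "model.layers.1.self_attn"))  # (None, {'rope_type': 'yarn', 'factor': 4.0})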