Commit 8d2b8c0

[Model] Add FlexOlmo model implementation (#24923)
Signed-off-by: Shane A <shanea@allenai.org>
1 parent b2155ed commit 8d2b8c0

File tree: 8 files changed, +286 -46 lines

docs/models/supported_models.md (1 addition, 0 deletions)

@@ -363,6 +363,7 @@ th {
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
 | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
 | `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
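
With this table entry and the registry changes later in the commit, the new checkpoints load through the standard vLLM entry points. A minimal, illustrative sketch (model name taken from the table above; prompt and sampling settings are arbitrary):

    from vllm import LLM, SamplingParams

    llm = LLM(model="allenai/FlexOlmo-7x7B-1T")
    outputs = llm.generate(
        ["Mixture-of-experts models route each token to"],
        SamplingParams(temperature=0.0, max_tokens=32),
    )
    print(outputs[0].outputs[0].text)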

tests/models/registry.py (1 addition, 0 deletions)

@@ -250,6 +250,7 @@ def check_available_online(
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
+    "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
vllm/model_executor/models/flex_olmo.py (new file: 157 additions, 0 deletions)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only FlexOlmo model compatible with HuggingFace weights."""

from typing import Optional

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
from vllm.transformers_utils.configs import FlexOlmoConfig

logger = init_logger(__name__)


class FlexOlmoAttention(OlmoeAttention):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps
        )
        self.q_norm = RMSNorm(
            self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps
        )


class FlexOlmoMoE(nn.Module):
    """A tensor-parallel MoE implementation for FlexOlmo that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        tp_size = get_tensor_model_parallel_world_size()

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hf_config.hidden_size,
            hf_config.num_experts,
            bias=False,
            return_bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Gate always runs at half / full precision for now.
        self.experts = FusedMoE(
            num_experts=hf_config.num_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=None,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

        self.top_k = hf_config.num_experts_per_tok

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        # Warning: The experts mutate the hidden state input! This messes up
        # basic things like the residual stream.
        final_hidden_states = self.experts(
            hidden_states=hidden_states.detach().clone(),
            router_logits=router_logits.float(),
        )

        return final_hidden_states.view(orig_shape)


class FlexOlmoDecoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, FlexOlmoConfig)

        self.self_attn = FlexOlmoAttention(
            vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
        )
        self.post_attention_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )
        self.post_feedforward_layernorm = RMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps
        )

        self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Attention block.
        residual = hidden_states
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None


class FlexOlmoForCausalLM(OlmoeForCausalLM):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
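
For orientation, the routing that `FlexOlmoMoE` delegates to `FusedMoE` amounts to a softmax over the router logits, top-k expert selection, and, because `renormalize=False`, no rescaling of the selected weights. The unfused reference below is an illustrative sketch of that computation, not the kernel vLLM actually runs; `gate` and `experts` are stand-in modules and float32 inputs are assumed:

    import torch
    from torch import nn

    def moe_reference(hidden_states: torch.Tensor, gate: nn.Linear,
                      experts: list[nn.Module], top_k: int) -> torch.Tensor:
        """Softmax routing, top-k experts, no renormalization of the selected weights."""
        router_logits = gate(hidden_states)                # (num_tokens, num_experts)
        routing_probs = torch.softmax(router_logits.float(), dim=-1)
        topk_weights, topk_ids = routing_probs.topk(top_k, dim=-1)
        out = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(experts):
            token_idx, slot = (topk_ids == expert_id).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            # Weight each selected token's expert output by its (unnormalized) routing prob.
            out[token_idx] += topk_weights[token_idx, slot].unsqueeze(-1) * expert(
                hidden_states[token_idx]
            )
        return out

The `.detach().clone()` passed to `self.experts` in the file above guards against the in-place mutation noted in the code comment, so the residual added back in `FlexOlmoDecoderLayer` stays intact.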

vllm/model_executor/models/olmoe.py (46 additions, 46 deletions)

@@ -17,15 +17,14 @@
 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch
 from torch import nn
-from transformers import OlmoeConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,

@@ -117,20 +116,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 class OlmoeAttention(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0

@@ -145,15 +145,15 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,

@@ -166,7 +166,7 @@ def __init__(
         self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )

@@ -218,28 +218,15 @@ def forward(


 class OlmoeDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )

@@ -280,12 +267,16 @@ def forward(

 @support_torch_compile
 class OlmoeModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()

         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config

         self.vocab_size = config.vocab_size
         self.config = config

@@ -295,9 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix
-            ),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)

@@ -339,7 +328,10 @@ def forward(
                 {"hidden_states": hidden_states, "residual": residual}
             )

-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+            layer_type=layer_type,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
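
The only behavioral change in `OlmoeModel.forward` is the final-norm branch: `OlmoeDecoderLayer` keeps using the fused add-and-norm path of vLLM's `RMSNorm` and passes the residual up the stack, whereas `FlexOlmoDecoderLayer` applies its post-attention and post-feedforward norms itself and returns `residual=None`. A small sketch of the two `RMSNorm` call forms this relies on (assumes a working vLLM install; sizes are illustrative):

    import torch
    from vllm.model_executor.layers.layernorm import RMSNorm

    norm = RMSNorm(hidden_size=8, eps=1e-5)
    x = torch.randn(4, 8)
    residual = torch.randn(4, 8)

    y = norm(x)                           # plain normalization -> Tensor (FlexOlmo path)
    y2, new_residual = norm(x, residual)  # fused add + norm -> (Tensor, Tensor) (Olmoe path)

The `if residual is not None` guard keeps the fused path for Olmoe while letting injected layer types, such as FlexOlmo's, fold the residual in themselves.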

vllm/model_executor/models/registry.py (1 addition, 0 deletions)

@@ -90,6 +90,7 @@
     "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
+    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),

vllm/transformers_utils/config.py (1 addition, 0 deletions)

@@ -74,6 +74,7 @@ def __getitem__(self, key):
     deepseek_vl_v2="DeepseekVLV2Config",
     deepseek_v3="DeepseekV3Config",
     deepseek_v32="DeepseekV3Config",
+    flex_olmo="FlexOlmoConfig",
     kimi_vl="KimiVLConfig",
     Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)

vllm/transformers_utils/configs/__init__.py (2 additions, 0 deletions)

@@ -17,6 +17,7 @@
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig

@@ -45,6 +46,7 @@
     "DeepseekV3Config",
     "DotsOCRConfig",
     "EAGLEConfig",
+    "FlexOlmoConfig",
     "RWConfig",
     "JAISConfig",
     "Lfm2MoeConfig",
