1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -389,6 +389,7 @@ th {
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -301,6 +301,7 @@ def check_available_online(
trust_remote_code=True),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
@DarkLight1337 (Member), Sep 11, 2025:
Need to add is_available_online=False (if the repo isn't available yet) and/or min_transformers_version (if the model isn't supported by the current version)

Contributor Author:
9542485: The problem was that I put olmo3 instead of olmo2 in the registry; I have fixed this in that commit. The two suggestions above don't apply, since the repo is public and this implementation is intended to work even before the transformers implementation is released.
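For reference, a registry entry carrying the reviewer's suggested flags would look roughly like the sketch below; the version string is a placeholder for illustration, not a value from this PR.

    "Olmo3ForCausalLM": _HfExamplesInfo(
        "shanearora/2025-sep-a-base-model",
        is_available_online=False,         # skip the online check if the HF repo is not public yet
        min_transformers_version="4.56",   # placeholder: lowest transformers release supporting the model
    ),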

"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m",
{"1b": "facebook/opt-iml-max-1.3b"}),
42 changes: 28 additions & 14 deletions vllm/model_executor/models/olmo2.py
@@ -52,10 +52,11 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader, is_pp_missing_parameter,
AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Olmo3Config


class Olmo2Attention(nn.Module):
@@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
assert isinstance(self.config, Olmo2Config)
assert isinstance(self.config, (Olmo2Config, Olmo3Config))

hidden_size = self.config.hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
@@ -111,22 +112,35 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.q_norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

# Rotary embeddings.
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta, # type: ignore
)
self.scaling = self.head_dim**-0.5

layer_idx = extract_layer_index(prefix)
sliding_window = None
if ((layer_types := getattr(self.config, "layer_types", None))
is not None and layer_types[layer_idx] == "sliding_attention"):
sliding_window = self.config.sliding_window

self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
prefix=prefix,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
)

# Rotary embeddings. Rope scaling is only applied on full attention
# layers.
self.rope_scaling = (self.config.rope_scaling
if sliding_window is None else None)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta, # type: ignore
rope_scaling=self.rope_scaling,
)

# Attention output projection.
@@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
assert isinstance(config, Olmo2Config)
assert isinstance(config, (Olmo2Config, Olmo3Config))
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size

@@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
assert isinstance(config, Olmo2Config)
assert isinstance(config, (Olmo2Config, Olmo3Config))
# Attention block.
self.self_attn = Olmo2Attention(vllm_config=vllm_config,
prefix=f"{prefix}.self_attn")
@@ -261,7 +275,7 @@ class Olmo2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
assert isinstance(self.config, Olmo2Config)
assert isinstance(self.config, (Olmo2Config, Olmo3Config))

self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
@@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
assert isinstance(config, Olmo2Config)
assert isinstance(config, (Olmo2Config, Olmo3Config))
self.config = config
self.model = Olmo2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
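To make the new control flow in Olmo2Attention easier to follow, here is a minimal standalone sketch of the per-layer selection above. It mirrors the diff but does not import vLLM; extract_layer_index_sketch is a simplified stand-in for the real helper, and the rope_scaling value in the usage example is hypothetical.

    from types import SimpleNamespace

    def extract_layer_index_sketch(prefix: str) -> int:
        # Assumes a module prefix such as "model.layers.3.self_attn".
        return next(int(part) for part in prefix.split(".") if part.isdigit())

    def select_window_and_rope_scaling(config, prefix: str):
        # Sliding window is only used on layers marked "sliding_attention".
        layer_idx = extract_layer_index_sketch(prefix)
        sliding_window = None
        layer_types = getattr(config, "layer_types", None)
        if layer_types is not None and layer_types[layer_idx] == "sliding_attention":
            sliding_window = config.sliding_window
        # Rope scaling is only applied on full-attention layers.
        rope_scaling = config.rope_scaling if sliding_window is None else None
        return sliding_window, rope_scaling

    cfg = SimpleNamespace(
        layer_types=["sliding_attention"] * 3 + ["full_attention"],
        sliding_window=4096,
        rope_scaling={"rope_type": "yarn"},  # hypothetical value, for illustration only
    )
    print(select_window_and_rope_scaling(cfg, "model.layers.2.self_attn"))  # (4096, None)
    print(select_window_and_rope_scaling(cfg, "model.layers.3.self_attn"))  # (None, {'rope_type': 'yarn'})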
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -120,6 +120,7 @@
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
1 change: 1 addition & 0 deletions vllm/transformers_utils/config.py
@@ -75,6 +75,7 @@ def __getitem__(self, key):
eagle="EAGLEConfig",
speculators="SpeculatorsConfig",
nemotron="NemotronConfig",
olmo3="Olmo3Config",
ovis="OvisConfig",
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -23,6 +23,7 @@
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.olmo3 import Olmo3Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
@@ -45,6 +46,7 @@
"NemotronConfig",
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"Olmo3Config",
"OvisConfig",
"SpeculatorsConfig",
"UltravoxConfig",
80 changes: 80 additions & 0 deletions vllm/transformers_utils/configs/olmo3.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers.configuration_utils import PretrainedConfig


class Olmo3Config(PretrainedConfig):

model_type = "olmo3"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=50304,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
use_cache=True,
pad_token_id=1,
bos_token_id=None,
eos_token_id=50279,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
rms_norm_eps=1e-5,
sliding_window=4096,
layer_types=None,
**kwargs,
):
# This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM
# in vLLM.
if "architectures" not in kwargs:
kwargs["architectures"] = ["Olmo2ForCausalLM"]
elif "Olmo3ForCausalLM" in kwargs["architectures"]:
kwargs["architectures"].remove("Olmo3ForCausalLM")
kwargs["architectures"].append("Olmo2ForCausalLM")

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

self.rms_norm_eps = rms_norm_eps

self.sliding_window = sliding_window
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
for i in range(self.num_hidden_layers)
]
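
As an illustration of the default layer_types above, every fourth layer gets full attention and the rest use sliding-window attention. With an example depth of 8 layers (real checkpoints may use a different depth):

    num_hidden_layers = 8  # example value, not taken from a released checkpoint
    layer_types = [
        "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
        for i in range(num_hidden_layers)
    ]
    print(layer_types)
    # ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
    #  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']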