Commit 5b3c09e

Make OlmoE suitable for inheritance by FlexOlmo
Signed-off-by: Shane A <shanea@allenai.org>
1 parent 8942c65 commit 5b3c09e

File tree: 1 file changed, +42 -46 lines


vllm/model_executor/models/olmoe.py

Lines changed: 42 additions & 46 deletions
@@ -16,15 +16,14 @@
 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
-from transformers import OlmoeConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
@@ -103,20 +102,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class OlmoeAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -131,15 +132,15 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -153,7 +154,7 @@ def __init__(
                               eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )
@@ -204,29 +205,15 @@ def forward(
 
 class OlmoeDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          4096)
 
         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )
 
@@ -270,12 +257,14 @@ def forward(
 @support_torch_compile
 class OlmoeModel(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: type[nn.Module] = OlmoeDecoderLayer):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
 
         self.vocab_size = config.vocab_size
         self.config = config
@@ -285,8 +274,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
 
@@ -328,7 +316,10 @@ def forward(
                 "residual": residual
             })
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -440,14 +431,19 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: type[nn.Module] = OlmoeDecoderLayer):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
+                                prefix=maybe_prefix(prefix, "model"),
+                                layer_type=layer_type)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config,
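
Taken together, the diff replaces the per-argument constructors with the standard vllm_config pattern and threads a new layer_type parameter from OlmoeForCausalLM through OlmoeModel into make_layers, which is what makes inheritance by FlexOlmo possible. Below is a minimal sketch of how a subclass could use this hook; the class names MyDecoderLayer and MyForCausalLM are hypothetical placeholders for illustration only, not FlexOlmo's actual implementation.

# Hypothetical sketch: reusing OlmoE through the new layer_type hook.
# MyDecoderLayer / MyForCausalLM are illustrative names, not FlexOlmo code.
from vllm.config import VllmConfig
from vllm.model_executor.models.olmoe import (OlmoeDecoderLayer,
                                              OlmoeForCausalLM)


class MyDecoderLayer(OlmoeDecoderLayer):
    """Inherits OlmoeDecoderLayer; a real subclass would override pieces
    (e.g. the MoE block) here instead of keeping the parent behavior."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)


class MyForCausalLM(OlmoeForCausalLM):
    """Builds the shared OlmoeModel with the custom decoder layer."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        # layer_type is forwarded to OlmoeModel and then to make_layers,
        # so every decoder layer is constructed as MyDecoderLayer.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=MyDecoderLayer)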
