Commit 864355a

update models
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 28af6b0 commit 864355a

21 files changed: 177 additions and 105 deletions

vllm/model_executor/models/aria.py
Lines changed: 8 additions & 4 deletions

@@ -12,12 +12,11 @@
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
@@ -269,6 +268,7 @@ def __init__(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            reduce_results=True,
             prefix=f"{prefix}.experts",
         )

@@ -287,8 +287,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)

-        # NOTE: hidden_states will be modified inplace by `SharedFusedMoE`
-        return self.experts(hidden_states, router_output)
+        sparse_expert_output = self.experts(hidden_states, router_output)
+
+        if self.shared_experts is not None:
+            return sparse_expert_output[0] + sparse_expert_output[1]
+        else:
+            return sparse_expert_output


 class AriaTextDecoderLayer(LlamaDecoderLayer):
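
A recurring pattern in this commit: the model's forward now combines the shared-expert and routed-expert outputs itself instead of relying on `SharedFusedMoE` to fuse and scale them internally. Below is a minimal sketch of the combination step aria.py now performs, assuming (as the diff above shows) that the fused call returns a `(shared_output, routed_output)` pair when shared experts are configured and a plain tensor otherwise; `combine_expert_outputs` is an illustrative helper, not a vLLM API.

from typing import Union

import torch


def combine_expert_outputs(
        expert_output: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        has_shared_experts: bool) -> torch.Tensor:
    # With shared experts, the fused MoE call yields (shared_output,
    # routed_output); the caller is now responsible for summing them.
    if has_shared_experts:
        return expert_output[0] + expert_output[1]
    # Without shared experts the call already returns the final tensor.
    return expert_output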

vllm/model_executor/models/bailing_moe.py
Lines changed: 34 additions & 33 deletions

@@ -36,7 +36,8 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -278,41 +279,28 @@ def __init__(
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts")
-
-            self.experts = SharedFusedMoE(
-                shared_experts=self.shared_experts,
-                fused_output_scaling_factor=self.routed_scaling_factor,
-                shared_output_scaling_factor=1.0,
-                num_experts=self.num_experts,
-                top_k=self.top_k,
-                hidden_size=self.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=self.norm_expert_prob,
-                quant_config=quant_config,
-                prefix=f"{prefix}.experts",
-                scoring_func=self.score_function,
-                e_score_correction_bias=self.gate.expert_bias,
-                num_expert_group=self.n_group,
-                topk_group=self.topk_group,
-                use_grouped_topk=self.use_grouped_topk,
-            )
         else:
-            self.experts = FusedMoE(
-                num_experts=self.num_experts,
-                top_k=self.top_k,
-                hidden_size=self.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                renormalize=self.norm_expert_prob,
-                quant_config=quant_config,
-                prefix=f"{prefix}.experts",
-                scoring_func=self.score_function,
-                e_score_correction_bias=self.gate.expert_bias,
-                num_expert_group=self.n_group,
-                topk_group=self.topk_group,
-                use_grouped_topk=self.use_grouped_topk,
-            )
             self.shared_experts = None

+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_experts,
+            fused_output_scaling_factor=self.routed_scaling_factor,
+            shared_output_scaling_factor=1.0,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
@@ -324,6 +312,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = final_hidden_states
+        else:
+            shared_output = None
+
+        final_hidden_states *= self.routed_scaling_factor
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_size)
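
bailing_moe.py now builds a single `SharedFusedMoE` whether or not shared experts exist and applies `routed_scaling_factor` in `forward`, taking over the work previously done by the `fused_output_scaling_factor` / `shared_output_scaling_factor` constructor arguments. A minimal sketch of the post-processing order, assuming the same tuple convention as above; the helper name is illustrative, not part of vLLM.

from typing import Optional, Union

import torch


def postprocess_moe_output(
        fused_moe_out: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
        has_shared_experts: bool,
        routed_scaling_factor: float) -> torch.Tensor:
    shared_output: Optional[torch.Tensor]
    if has_shared_experts:
        shared_output, final_hidden_states = fused_moe_out
    else:
        shared_output = None
        final_hidden_states = fused_moe_out

    # Scale only the routed path, then add the unscaled shared path: the
    # same effect as fused_output_scaling_factor=s with
    # shared_output_scaling_factor=1.0 in the old constructor arguments.
    final_hidden_states = final_hidden_states * routed_scaling_factor
    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output
    return final_hidden_states
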

vllm/model_executor/models/deepseek_v2.py
Lines changed: 27 additions & 21 deletions

@@ -36,7 +36,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, VllmConfig,
+from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -133,7 +133,6 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
-        model_config: ModelConfig,
         parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -184,8 +183,6 @@ def __init__(

         if config.n_shared_experts is None:
             self.shared_experts = None
-            fused_output_scaling_factor = 1.0
-            shared_output_scaling_factor = 1.0
         else:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -196,28 +193,17 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 is_sequence_parallel=self.is_sequence_parallel,
-                reduce_results=False,  # XXXXX
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )

-            # Fix FP16 overflow
-            # See DeepseekV2DecoderLayer for more details.
-            if model_config.dtype != torch.float16:
-                fused_output_scaling_factor = self.routed_scaling_factor
-                shared_output_scaling_factor = 1.0
-            else:
-                fused_output_scaling_factor = 1.0
-                shared_output_scaling_factor = (1. /
-                                                self.routed_scaling_factor)
-
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
-            fused_output_scaling_factor=fused_output_scaling_factor,
-            shared_output_scaling_factor=shared_output_scaling_factor,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -247,15 +233,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+        else:
+            shared_output = None
+            final_hidden_states = fused_moe_out
+
+        # Fix FP16 overflow
+        # See DeepseekV2DecoderLayer for more details.
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states *= self.routed_scaling_factor
+        elif self.shared_experts is not None:
+            assert shared_output is not None
+            shared_output *= (1. / self.routed_scaling_factor)
+
+        if self.shared_experts is not None:
+            assert shared_output is not None
+            final_hidden_states += shared_output

         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0)
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))

-        # TODO(bnell): why is this view needed?
         return final_hidden_states.view(num_tokens, hidden_dim)

@@ -1009,7 +1016,6 @@ def __init__(self,
               and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
-                model_config=model_config,
                 parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
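
deepseek_v2.py moves the FP16-overflow workaround from constructor-time scaling factors (which needed `model_config.dtype`, now dropped from the signature) into `forward`, where the activation dtype is checked directly. A minimal sketch of that branch, written as a free function purely for illustration:

from typing import Optional

import torch


def apply_routed_scaling_fp16_safe(
        final_hidden_states: torch.Tensor,
        shared_output: Optional[torch.Tensor],
        routed_scaling_factor: float) -> torch.Tensor:
    # Fix FP16 overflow (see DeepseekV2DecoderLayer): in float16 the routed
    # output is left unscaled and the shared output is scaled down instead.
    if final_hidden_states.dtype != torch.float16:
        final_hidden_states = final_hidden_states * routed_scaling_factor
    elif shared_output is not None:
        shared_output = shared_output * (1.0 / routed_scaling_factor)
    if shared_output is not None:
        final_hidden_states = final_hidden_states + shared_output
    return final_hidden_states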

vllm/model_executor/models/dots1.py
Lines changed: 14 additions & 4 deletions

@@ -35,7 +35,9 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -138,11 +140,11 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
-            fused_output_scaling_factor=self.routed_scaling_factor,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -159,9 +161,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)

         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor

+        if self.shared_experts is not None:
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[
+                1]
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)

vllm/model_executor/models/ernie45_moe.py
Lines changed: 10 additions & 0 deletions

@@ -147,6 +147,7 @@ def __init__(
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
             renormalize=True,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
@@ -162,6 +163,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

+        if self.has_shared_experts:
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[
+                1]
+
+        if self.tp_size > 1:
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
+
         return final_hidden_states.view(orig_shape)

vllm/model_executor/models/ernie45_vl_moe.py
Lines changed: 14 additions & 3 deletions

@@ -70,10 +70,10 @@ def __init__(self,
         self.shared_experts = shared_experts

     def forward(self, x):
-        out = super().forward(x)
         if self.shared_experts is not None:
-            out = out + self.shared_experts(x)
-        return out
+            return self.shared_experts(x) + super().forward(x)
+        else:
+            return super().forward(x)


 class Ernie4_5_VLMoeAttention(nn.Module):
@@ -244,6 +244,7 @@ def __init__(
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size[0],
+            reduce_results=False,
             renormalize=True,
             quant_config=quant_config,
             e_score_correction_bias=self.e_score_correction_bias[0],
@@ -275,6 +276,7 @@ def __init__(
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size[1],
+            reduce_results=False,
             renormalize=True,
             quant_config=quant_config,
             e_score_correction_bias=self.e_score_correction_bias[1],
@@ -337,6 +339,15 @@ def forward(
         final_hidden_states = self.text_experts(
             hidden_states=hidden_states, router_logits=text_router_logits)

+        if self.has_shared_experts:
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[
+                1]
+
+        if self.tp_size > 1:
+            final_hidden_states = (
+                self.text_experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
+
         return final_hidden_states.view(orig_shape)
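
The Ernie4_5_VLMoeMLP change evaluates the shared-expert branch on the original input before delegating to the parent MLP, presumably so the shared path is unaffected if the fused path mutates `x` in place (compare the note removed from aria.py). A small standalone sketch of the wrapper pattern, with a toy linear layer standing in for the real MLP; none of these class names are vLLM APIs.

from typing import Optional

import torch
from torch import nn


class MLPWithSharedExperts(nn.Module):
    """Toy wrapper: a base MLP plus an optional shared-experts branch."""

    def __init__(self,
                 hidden_size: int,
                 shared_experts: Optional[nn.Module] = None):
        super().__init__()
        self.base_mlp = nn.Linear(hidden_size, hidden_size)
        self.shared_experts = shared_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.shared_experts is not None:
            # Run the shared branch on the untouched input first, then add
            # the base output (the ordering this commit switches to).
            return self.shared_experts(x) + self.base_mlp(x)
        return self.base_mlp(x)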

vllm/model_executor/models/glm4_moe.py
Lines changed: 17 additions & 6 deletions

@@ -162,12 +162,11 @@ def __init__(

         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
-            fused_output_scaling_factor=self.routed_scaling_factor,
-            shared_output_scaling_factor=1.0,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -179,8 +178,7 @@ def __init__(
             routed_scaling_factor=1.0,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts,
-        )
+            num_redundant_experts=self.n_redundant_experts)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -189,9 +187,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(dtype=torch.float32))

-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+            assert shared_output is not None
+            final_hidden_states = \
+                final_hidden_states * self.routed_scaling_factor\
+                + shared_output
+        else:
+            final_hidden_states = fused_moe_out * self.routed_scaling_factor

+        if self.tp_size > 1:
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
         return final_hidden_states.view(num_tokens, hidden_dim)
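
The common thread across most of these forward changes: the fused expert layers are constructed with `reduce_results=False`, the partial shared and routed outputs are combined locally, and a single tensor-parallel reduction happens at the end (via `tensor_model_parallel_all_reduce` or the experts' `maybe_all_reduce_tensor_model_parallel`) rather than reducing each branch separately. A minimal sketch of that deferred-reduction idea, using plain `torch.distributed` as a stand-in for vLLM's tensor-parallel helpers:

from typing import Optional

import torch
import torch.distributed as dist


def combine_then_reduce(
        shared_output: Optional[torch.Tensor],
        routed_output: torch.Tensor,
        routed_scaling_factor: float,
        tp_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # Combine the un-reduced partial outputs locally first...
    out = routed_output * routed_scaling_factor
    if shared_output is not None:
        out = out + shared_output
    # ...then issue one all-reduce over the tensor-parallel group instead of
    # one per branch.
    if tp_group is not None and dist.get_world_size(group=tp_group) > 1:
        dist.all_reduce(out, group=tp_group)
    return out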
