Commit 691fafb

Isotr0py authored and jinzhen-lin committed
[Model]: Fused MoE for nomic-embed-text-v2-moe (vllm-project#18321)
Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
1 parent 95276b7 commit 691fafb

2 files changed: +142 -113 lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 36 additions & 11 deletions
@@ -7,6 +7,7 @@
 from typing import Any, Callable, Optional
 
 import torch
+import torch.nn.functional as F
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           activation: str = "silu",
+                          is_act_and_mul: bool = True,
                           apply_router_weight_on_input: bool = False,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a8: bool = False,
@@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       activation, apply_router_weight_on_input, use_fp8_w8a8,
+                       activation, is_act_and_mul,
+                       apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1167,6 +1171,7 @@ def outplace_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1183,13 +1188,12 @@ def outplace_fused_experts(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None) -> torch.Tensor:
-    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, activation, apply_router_weight_on_input,
-                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, use_mxfp4_w4a4,
-                              per_channel_quant, global_num_experts,
-                              expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+    return fused_experts_impl(
+        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
+        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
+        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
+        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
 
 
 def outplace_fused_experts_fake(
@@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
@@ -1253,6 +1258,7 @@ def fused_experts(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1283,6 +1289,8 @@ def fused_experts(
                             or is_blackwell_deep_gemm_used())
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
+        assert is_act_and_mul, (
+            "DeepGemm only supports is_act_and_mul=True for now.")
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
@@ -1319,6 +1327,7 @@ def fused_experts(
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         activation=activation,
+        is_act_and_mul=is_act_and_mul,
         apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
@@ -1345,6 +1354,7 @@ def fused_experts_impl(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1503,14 +1513,21 @@ def fused_experts_impl(
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
-        if activation == "silu":
+        # Activation function with multiplication
+        if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "gelu":
+        elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        # Activation function without multiplication
+        elif activation == "silu":
+            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
         else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
+                             f"with is_act_and_mul={is_act_and_mul}.")
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
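For reference, the two groups of branches above correspond to gated ("activation-and-mul") and ungated activation of the experts' first GEMM output. A minimal PyTorch sketch of what each branch computes (illustrative only, not part of the commit; the helper names are made up):

import torch
import torch.nn.functional as F

def gated_act(x: torch.Tensor) -> torch.Tensor:
    # is_act_and_mul=True path: mirrors torch.ops._C.silu_and_mul, which
    # splits the 2*d columns produced by the first GEMM into a gate half
    # and a value half and multiplies them elementwise.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def ungated_act(x: torch.Tensor) -> torch.Tensor:
    # is_act_and_mul=False path: a plain elementwise activation, so the
    # intermediate width stays at d.
    return F.silu(x)

x = torch.randn(4, 16)                        # 4 tokens, 2*d = 16 columns
assert gated_act(x).shape == (4, 8)           # gated path halves the width
assert ungated_act(x[:, :8]).shape == (4, 8)

The gelu branches behave the same way with F.gelu in place of F.silu.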
@@ -1555,6 +1572,7 @@ def fused_moe(
         renormalize: bool,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
@@ -1591,6 +1609,9 @@ def fused_moe(
         Defaults to False.
     - activation (str): The activation function to apply after the first
         MoE layer.
+    - is_act_and_mul (bool): If True, use activation-and-mul function for
+        activation (self-gated activation), otherwise use activation function
+        for activation (ungated activation).
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
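In weight-layout terms, the docstring's "self-gated" case packs the gate and up projections into w1 (per-expert shape [2N, K]), while the "ungated" case keeps w1 at [N, K]; w2 stays [K, N] either way. A dense single-expert sketch under those assumed layouts (K = hidden size, N = intermediate size; the helper name is illustrative, not the fused kernel itself):

import torch
import torch.nn.functional as F

def expert_mlp_ref(x, w1_e, w2_e, is_act_and_mul=True, act=F.silu):
    # w1_e: [2N, K] when self-gated, [N, K] when ungated; w2_e: [K, N].
    h = x @ w1_e.t()
    if is_act_and_mul:
        n = h.shape[-1] // 2
        h = act(h[..., :n]) * h[..., n:]    # SwiGLU-style gate * value
    else:
        h = act(h)                          # plain MLP experts
    return h @ w2_e.t()

K, N = 16, 32
x = torch.randn(4, K)
y_gated = expert_mlp_ref(x, torch.randn(2 * N, K), torch.randn(K, N))
y_plain = expert_mlp_ref(x, torch.randn(N, K), torch.randn(K, N),
                         is_act_and_mul=False, act=F.gelu)
assert y_gated.shape == y_plain.shape == (4, K)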
@@ -1627,6 +1648,9 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
+    if not is_act_and_mul:
+        assert inplace is False, (
+            "is_act_and_mul=False is not supported with inplace=True")
 
     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1647,6 +1671,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        is_act_and_mul=is_act_and_mul,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
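Taken together, the new parameter could be exercised roughly as below. This is a hedged usage sketch, not a test from the commit: the shapes are invented, a CUDA device is assumed, and the leading positional order (hidden_states, w1, w2, gating_output, topk, ...) is taken from the surrounding fused_moe signature rather than spelled out in these hunks.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, K, N, M, topk = 8, 768, 3072, 4, 2   # experts, hidden, intermediate, tokens, top-k
device, dtype = "cuda", torch.float16

hidden_states = torch.randn(M, K, device=device, dtype=dtype)
w1 = torch.randn(E, N, K, device=device, dtype=dtype)   # no 2*N gate/up concat when ungated
w2 = torch.randn(E, K, N, device=device, dtype=dtype)
gating_output = torch.randn(M, E, device=device, dtype=torch.float32)

out = fused_moe(hidden_states, w1, w2, gating_output, topk,
                renormalize=True,
                inplace=False,            # inplace is rejected when is_act_and_mul=False
                activation="gelu",        # ungated GELU experts, the pattern this PR targets
                is_act_and_mul=False)
assert out.shape == (M, K)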
