@@ -7,6 +7,7 @@
 from typing import Any, Callable, Optional
 
 import torch
+import torch.nn.functional as F
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           activation: str = "silu",
+                          is_act_and_mul: bool = True,
                           apply_router_weight_on_input: bool = False,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a8: bool = False,
@@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       activation, apply_router_weight_on_input, use_fp8_w8a8,
+                       activation, is_act_and_mul,
+                       apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1167,6 +1171,7 @@ def outplace_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1183,13 +1188,12 @@ def outplace_fused_experts(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None) -> torch.Tensor:
-    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, activation, apply_router_weight_on_input,
-                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, use_mxfp4_w4a4,
-                              per_channel_quant, global_num_experts,
-                              expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+    return fused_experts_impl(
+        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
+        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
+        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
+        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
 
 
 def outplace_fused_experts_fake(
@@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
@@ -1253,6 +1258,7 @@ def fused_experts(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1283,6 +1289,8 @@ def fused_experts(
                             or is_blackwell_deep_gemm_used())
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
+        assert is_act_and_mul, (
+            "DeepGemm only supports is_act_and_mul=True for now.")
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
@@ -1319,6 +1327,7 @@ def fused_experts(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             activation=activation,
+            is_act_and_mul=is_act_and_mul,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
@@ -1345,6 +1354,7 @@ def fused_experts_impl(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1503,14 +1513,21 @@ def fused_experts_impl(
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)
 
-        if activation == "silu":
+        # Activation function with multiplication
+        if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "gelu":
+        elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        # Activation function without multiplication
+        elif activation == "silu":
+            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
         else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
+                             f"with is_act_and_mul={is_act_and_mul}.")
 
         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
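For reference, the two code paths above compute the following (a minimal pure-PyTorch sketch with illustrative tensor and function names, not the kernel's actual buffers):

import torch
import torch.nn.functional as F

def activation_reference(x: torch.Tensor, is_act_and_mul: bool) -> torch.Tensor:
    # x is the output of the first grouped GEMM, shape (num_tokens, N).
    if is_act_and_mul:
        # Gated ("act-and-mul") path, matching silu_and_mul: the first half
        # gates the second half, so the result has width N // 2.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    # Ungated path: apply the activation elementwise, keeping the full width N.
    return F.silu(x)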
@@ -1555,6 +1572,7 @@ def fused_moe(
         renormalize: bool,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
@@ -1591,6 +1609,9 @@ def fused_moe(
         Defaults to False.
     - activation (str): The activation function to apply after the first
         MoE layer.
+    - is_act_and_mul (bool): If True, use the gated activation-and-mul
+        function (self-gated activation, e.g. silu_and_mul); otherwise use
+        the plain activation function (ungated activation, e.g. silu).
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -1627,6 +1648,9 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
+    if not is_act_and_mul:
+        assert inplace is False, (
+            "is_act_and_mul=False is not supported with inplace=True")
 
     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1647,6 +1671,7 @@ def fused_moe(
                          topk_ids,
                          inplace=inplace,
                          activation=activation,
+                         is_act_and_mul=is_act_and_mul,
                          use_fp8_w8a8=use_fp8_w8a8,
                          use_int8_w8a8=use_int8_w8a8,
                          use_int8_w8a16=use_int8_w8a16,
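A hypothetical call site for the new flag, assuming the fused_moe entry point defined in this file takes (hidden_states, w1, w2, gating_output, topk, renormalize, ...); the import path, shapes, and dtypes below are illustrative assumptions, and a CUDA build of vLLM is required:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, K, N, num_tokens, topk = 8, 1024, 2048, 4, 2
hidden_states = torch.randn(num_tokens, K, dtype=torch.float16, device="cuda")
# Ungated MoE: w1 maps K -> N per expert (a gated layout would use 2 * N rows).
w1 = torch.randn(E, N, K, dtype=torch.float16, device="cuda") * 0.01
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda") * 0.01
gating_output = torch.randn(num_tokens, E, dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk,
                renormalize=True,
                inplace=False,          # asserted False when is_act_and_mul=False
                activation="gelu",
                is_act_and_mul=False)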