 def is_rocm_aiter_moe_enabled() -> bool:
     return current_platform.is_rocm() \
         and envs.VLLM_ROCM_USE_AITER_MOE \
-        and envs.VLLM_ROCM_USE_AITER \
-
-
-def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
-    return is_rocm_aiter_moe_enabled() and \
-        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def rocm_aiter_asm_moe_tkw1(hidden_states,
+                            w1,
+                            w2,
+                            topk_weight,
+                            topk_ids,
+                            fc1_scale=None,
+                            fc2_scale=None,
+                            fc1_smooth_scale=None,
+                            fc2_smooth_scale=None,
+                            a16=False,
+                            per_tensor_quant_scale=None,
+                            expert_mask=None,
+                            activation_str: str = "silu") -> None:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = \
+        ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
+
+    return asm_moe_tkw1(hidden_states,
+                        w1,
+                        w2,
+                        topk_weight,
+                        topk_ids,
+                        fc1_scale=fc1_scale,
+                        fc2_scale=fc2_scale,
+                        fc1_smooth_scale=fc1_smooth_scale,
+                        fc2_smooth_scale=fc2_smooth_scale,
+                        a16=a16,
+                        per_tensor_quant_scale=per_tensor_quant_scale,
+                        expert_mask=expert_mask,
+                        activation=activation)


 def rocm_aiter_fused_experts(
-        *,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        use_fp8_w8a8: bool = False,
-        apply_router_weight_on_input: bool = False,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None,
-        expert_mask: Optional[torch.Tensor] = None,
-        **kwagrs  # Ignore additional keyword arguments
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        inplace: bool = False,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None,
+        allow_deep_gemm: bool = False,
 ) -> torch.Tensor:

     import aiter as rocm_aiter
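Reviewer note: a minimal sketch of how the `is_rocm_aiter_moe_enabled()` gate defined above is typically consumed when picking a MoE backend. The Triton fallback name `fused_experts` is an assumption for illustration, not part of this diff.

# Hedged sketch: dispatch on the env-var gate (illustrative only).
if is_rocm_aiter_moe_enabled():
    moe_impl = rocm_aiter_fused_experts   # AITER kernels on ROCm
else:
    moe_impl = fused_experts              # assumed default Triton implementation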
@@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)

-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)

-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+    if (block_shape is not None) and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )

-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None

         local_E = E = w1.shape[0]
-        if expert_mask is not None:
-            E = expert_mask.numel()
+        if expert_map is not None:
+            E = expert_map.numel()

         topk = topk_ids.shape[1]
         model_dim = w1.shape[-1]
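Reviewer note: the casts added above are the whole datatype contract; a self-contained toy (shapes are arbitrary, chosen only for illustration) showing the coercion the wrapper now performs before any AITER kernel is called:

import torch

topk_weights = torch.rand(8, 2, dtype=torch.bfloat16)       # (num_tokens, topk)
topk_ids = torch.randint(0, 16, (8, 2), dtype=torch.int64)  # typical router output dtype

# rocm_aiter_fused_experts casts both before dispatching to AITER.
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
assert topk_weights.dtype == torch.float32 and topk_ids.dtype == torch.int32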
@@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
             E,
             model_dim,
             dtype,
-            expert_mask=expert_mask)
+            expert_mask=expert_map)

         a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
         rocm_aiter.fmoe_fp8_blockscale_g1u1(
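Reviewer note: `per_token_group_quant_fp8` is vLLM's existing fp8 helper; the snippet below is only a sketch of the per-token group-quantization idea it implements, not the library code (group size and fp8 dtype are assumptions for illustration).

import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    # x: (num_tokens, hidden_dim) with hidden_dim divisible by group_size.
    num_tokens, hidden_dim = x.shape
    xg = x.view(num_tokens, hidden_dim // group_size, group_size).float()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per (token, group), sized so the group's max maps to fp8_max.
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    xq = (xg / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return xq.view(num_tokens, hidden_dim), scales.squeeze(-1)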
@@ -102,7 +138,33 @@ def rocm_aiter_fused_experts(
         )
         return out_asm

+    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when"
+            " `apply_router_weight_on_input` is True")
+
+        return rocm_aiter_asm_moe_tkw1(hidden_states,
+                                       w1,
+                                       w2,
+                                       topk_weights,
+                                       topk_ids,
+                                       fc1_scale=w1_scale,
+                                       fc2_scale=w2_scale,
+                                       fc1_smooth_scale=None,
+                                       fc2_smooth_scale=None,
+                                       a16=False,
+                                       per_tensor_quant_scale=None,
+                                       expert_mask=expert_map,
+                                       activation_str=activation)
+
     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
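Reviewer note: the topk == 1 restriction in the `apply_router_weight_on_input` paths above exists because, with a single routed expert, folding the routing weight into the activations commutes with the expert GEMM. A short plain-torch check of that identity (illustrative only):

import torch

x = torch.randn(4, 16)   # (num_tokens, model_dim)
W = torch.randn(16, 32)  # one expert's weight matrix
w = torch.rand(4, 1)     # per-token routing weight, topk == 1

# Scaling the input first (as done right before ck_moe above) gives the same
# result as scaling the GEMM output.
assert torch.allclose((w * x) @ W, w * (x @ W), atol=1e-5)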