 def is_rocm_aiter_moe_enabled() -> bool:
     return current_platform.is_rocm() \
         and envs.VLLM_ROCM_USE_AITER_MOE \
-        and envs.VLLM_ROCM_USE_AITER \
-
-
-def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
-    return is_rocm_aiter_moe_enabled() and \
-        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def rocm_aiter_asm_moe_tkw1(hidden_states,
+                            w1,
+                            w2,
+                            topk_weight,
+                            topk_ids,
+                            fc1_scale=None,
+                            fc2_scale=None,
+                            fc1_smooth_scale=None,
+                            fc2_smooth_scale=None,
+                            a16=False,
+                            per_tensor_quant_scale=None,
+                            expert_mask=None,
+                            activation_str: str = "silu") -> torch.Tensor:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = \
+        ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
+
+    return asm_moe_tkw1(hidden_states,
+                        w1,
+                        w2,
+                        topk_weight,
+                        topk_ids,
+                        fc1_scale=fc1_scale,
+                        fc2_scale=fc2_scale,
+                        fc1_smooth_scale=fc1_smooth_scale,
+                        fc2_smooth_scale=fc2_smooth_scale,
+                        a16=a16,
+                        per_tensor_quant_scale=per_tensor_quant_scale,
+                        expert_mask=expert_mask,
+                        activation=activation)


 def rocm_aiter_fused_experts(
-        *,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        use_fp8_w8a8: bool = False,
-        apply_router_weight_on_input: bool = False,
-        w1_scale: Optional[torch.Tensor] = None,
-        w2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[List[int]] = None,
-        expert_mask: Optional[torch.Tensor] = None,
-        **kwagrs  # Ignore additional keyword arguments
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+    allow_deep_gemm: bool = False,
 ) -> torch.Tensor:

     import aiter as rocm_aiter
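For orientation, a minimal sketch of how this widened entry point might be invoked on a ROCm build with AITER installed. It is not part of the commit: the sizes are made-up, and w1 is assumed to carry the fused gate/up projection per expert, as in vLLM's other fused-MoE kernels.

import torch

# Illustrative sizes only: 8 experts, top-2 routing, 4 tokens,
# hidden size 128, intermediate size 256.
E, K, N, M, topk = 8, 128, 256, 4, 2

hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16, device="cuda")  # fused gate+up proj (assumed layout)
w2 = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(M, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(0, E, (M, topk), dtype=torch.int32, device="cuda")

# No quantization flags are set, so this call falls through to rocm_aiter.ck_moe.
out = rocm_aiter_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids,
                               activation="silu")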
@@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)

-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)

-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+    if (block_shape is not None) and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for block scaled moe"
+        )

-    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None

         local_E = E = w1.shape[0]
-        if expert_mask is not None:
-            E = expert_mask.numel()
+        if expert_map is not None:
+            E = expert_map.numel()

         topk = topk_ids.shape[1]
         model_dim = w1.shape[-1]
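A small sketch of the expert_map convention assumed above (global expert id mapped to a local id, with -1 for experts owned by other ranks, as this editor understands the convention used elsewhere in vLLM); it shows why E comes from expert_map.numel() while local_E comes from w1.shape[0]. All sizes are made up for illustration.

import torch

# Hypothetical expert-parallel layout: 8 global experts, this rank holds 4..7.
global_num_experts, local_num_experts, first_local = 8, 4, 4
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[first_local:first_local + local_num_experts] = torch.arange(
    local_num_experts, dtype=torch.int32)

w1 = torch.empty(local_num_experts, 512, 256)  # per-rank expert weights (illustrative shape)

local_E = E = w1.shape[0]   # 4 experts resident on this rank
if expert_map is not None:
    E = expert_map.numel()  # 8: the global expert count used for sorting
print(local_E, E)           # -> 4 8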
@@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
                                                E,
                                                model_dim,
                                                dtype,
-                                               expert_mask=expert_mask)
+                                               expert_mask=expert_map)

         a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
         rocm_aiter.fmoe_fp8_blockscale_g1u1(
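The call above quantizes the activations per token, in groups of scale_blk_k elements along the hidden dimension, before the block-scaled kernel runs. Below is a simplified, unfused reference of that idea, assuming e4m3 with a maximum representable value of 448; the real per_token_group_quant_fp8 in vLLM is a fused kernel and handles edge cases this sketch ignores.

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int,
                                  fp8_max: float = 448.0):
    """Toy reference: one scale per (token, group) along the last dimension."""
    num_tokens, hidden = x.shape
    assert hidden % group_size == 0, "hidden dim must divide evenly into groups"
    groups = x.float().view(num_tokens, hidden // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = amax / fp8_max                        # (num_tokens, num_groups, 1)
    q = (groups / scales).clamp(-fp8_max, fp8_max)
    q = q.view(num_tokens, hidden).to(torch.float8_e4m3fn)
    return q, scales.squeeze(-1)

x = torch.randn(4, 256, dtype=torch.bfloat16)
a1_ref, a1_scale_ref = per_token_group_quant_fp8_ref(x, group_size=128)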
@@ -102,7 +138,33 @@ def rocm_aiter_fused_experts(
         )
         return out_asm

+    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`.
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when"
+            " `apply_router_weight_on_input` is True")
+
+        return rocm_aiter_asm_moe_tkw1(hidden_states,
+                                       w1,
+                                       w2,
+                                       topk_weights,
+                                       topk_ids,
+                                       fc1_scale=w1_scale,
+                                       fc2_scale=w2_scale,
+                                       fc1_smooth_scale=None,
+                                       fc2_smooth_scale=None,
+                                       a16=False,
+                                       per_tensor_quant_scale=None,
+                                       expert_mask=expert_map,
+                                       activation_str=activation)
+
     elif use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is not supported for fp8_w8a8")
         return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                            w1=w1,
                                            w2=w2,
@@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
                                            fc2_smooth_scale=None,
                                            a16=False)

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     return rocm_aiter.ck_moe(hidden_states=hidden_states,
                              w1=w1,
                              w2=w2,
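Taken together, the reworked function now selects a kernel in a fixed order. The helper below is only a readability aid that mirrors the branch conditions shown in the hunks above; it is not part of the commit.

def _aiter_moe_path(use_fp8_w8a8: bool,
                    block_shape,
                    per_channel_quant: bool,
                    apply_router_weight_on_input: bool) -> str:
    """Mirror of the branch order in rocm_aiter_fused_experts (illustrative)."""
    if block_shape is not None and use_fp8_w8a8:
        # Block-scaled FP8 path; asserts that router weights are not pre-applied.
        return "fmoe_fp8_blockscale_g1u1"
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # tkw1 kernel: applies topk_weights after the first GEMM; topk must be 1.
        return "asm_moe_tkw1"
    if use_fp8_w8a8:
        # Per-tensor FP8 path; router weights must not be pre-applied here either.
        return "asm_moe"
    # Fallback: optionally fold topk_weights into the activations (topk == 1),
    # then run the unquantized CK kernel.
    return "ck_moe"

assert _aiter_moe_path(True, [128, 128], False, False) == "fmoe_fp8_blockscale_g1u1"
assert _aiter_moe_path(True, None, True, True) == "asm_moe_tkw1"
assert _aiter_moe_path(False, None, False, True) == "ck_moe"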