@@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
12971297 qintermediate_cache2 = intermediate_cache2
12981298 a2q_scale = a2_scale
12991299
1300- invoke_fused_moe_kernel (
1301- qintermediate_cache2 ,
1302- w2 ,
1303- intermediate_cache3 ,
1304- a2q_scale ,
1305- w2_scale ,
1306- w2_zp ,
1307- curr_topk_weights ,
1308- sorted_token_ids ,
1309- expert_ids ,
1310- num_tokens_post_padded ,
1311- False , #True,
1312- 1 ,
1313- config ,
1314- compute_type = compute_type ,
1315- use_fp8_w8a8 = use_fp8_w8a8 ,
1316- use_int8_w8a16 = use_int8_w8a16 ,
1317- use_int4_w4a16 = use_int4_w4a16 ,
1318- block_shape = block_shape )
1319-
1320- if True :
1321- intermediate_cache3 = intermediate_cache3 .view (- 1 , top_k_num , K )
1322- intermediate_cache3 .mul_ (
1323- curr_topk_weights .view (tokens_in_chunk , - 1 , 1 ))
1300+ invoke_fused_moe_kernel (qintermediate_cache2 ,
1301+ w2 ,
1302+ intermediate_cache3 ,
1303+ a2q_scale ,
1304+ w2_scale ,
1305+ w2_zp ,
1306+ curr_topk_weights ,
1307+ sorted_token_ids ,
1308+ expert_ids ,
1309+ num_tokens_post_padded ,
1310+ True ,
1311+ 1 ,
1312+ config ,
1313+ compute_type = compute_type ,
1314+ use_fp8_w8a8 = use_fp8_w8a8 ,
1315+ use_int8_w8a16 = use_int8_w8a16 ,
1316+ use_int4_w4a16 = use_int4_w4a16 ,
1317+ block_shape = block_shape )
13241318
13251319 ops .moe_sum (intermediate_cache3 .view (* intermediate_cache3 .shape ),
13261320 out_hidden_states [begin_chunk_idx :end_chunk_idx ])
0 commit comments