@@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data(
241241 // mm to run it for.
242242 int32_t version_num = get_sm_version_num ();
243243#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
244- (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90 )
244+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 )
245245 get_cutlass_moe_mm_data_caller (topk_ids, expert_offsets, problem_sizes1,
246246 problem_sizes2, input_permutation,
247247 output_permutation, num_experts, n, k,
@@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data(
252252 false ,
253253 " No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
254254 " CUDA device capability: " ,
255- version_num, " . Required capability: 90" );
255+ version_num, " . Required capability: 90 or 100 " );
256256}
257257
258258void get_cutlass_pplx_moe_mm_data (torch::Tensor& expert_offsets,
@@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
265265 // This function currently gets compiled only if we have a valid cutlass moe
266266 // mm to run it for.
267267 int32_t version_num = get_sm_version_num ();
268- #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
268+ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
269+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
269270 get_cutlass_pplx_moe_mm_data_caller (expert_offsets, problem_sizes1,
270271 problem_sizes2, expert_num_tokens,
271272 num_local_experts, padded_m, n, k);
@@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
275276 false ,
276277 " No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
277278 " for CUDA device capability: " ,
278- version_num, " . Required capability: 90" );
279+ version_num, " . Required capability: 90 or 100 " );
279280}
280281
281282void cutlass_scaled_mm_azp (torch::Tensor& c, torch::Tensor const & a,
0 commit comments