@@ -9,6 +9,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512;
 // threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
 constexpr int SWAP_AB_THRESHOLD = 64;
 
+template <bool SWAP_AB>
 __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
@@ -117,11 +118,21 @@ void get_cutlass_moe_mm_data_caller(
   torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
 
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
-  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
-      static_cast<int32_t*>(problem_sizes1.data_ptr()),
-      static_cast<int32_t*>(problem_sizes2.data_ptr()),
-      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
+  if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        static_cast<const int32_t*>(topk_ids.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
+        k);
+  } else {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        static_cast<const int32_t*>(topk_ids.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
+        k);
+  }
   if (blockscale_offsets.has_value()) {
     compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
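
The kernel body is outside the hunks shown above, so the standalone CUDA sketch below is an assumption about how the new SWAP_AB template parameter could be consumed, not the PR's confirmed implementation. It illustrates the idea behind the dispatch: each expert's token count becomes the M dimension of its grouped-GEMM problem, and when at most SWAP_AB_THRESHOLD tokens are routed (the else branch above), emitting the shape with M and N swapped puts the small token count on operand B. The names sketch_problem_sizes and kThreads, and the flat {m, n, k} layout of problem_sizes, are hypothetical.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreads = 256;  // hypothetical block size for this sketch

// Each block counts the tokens routed to one expert, then writes that
// expert's {m, n, k} GEMM shape; SWAP_AB=true emits {n, m, k} instead so the
// small token count lands on operand B (the assumed purpose of the flag).
template <bool SWAP_AB>
__global__ void sketch_problem_sizes(const int32_t* __restrict__ topk_ids,
                                     int32_t* problem_sizes,
                                     int32_t* atomic_buffer, int topk_length,
                                     int n, int k) {
  int expert_id = blockIdx.x;
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += blockDim.x) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();  // all per-thread counts land before thread 0 reads
  if (threadIdx.x == 0) {
    int m = atomic_buffer[expert_id];
    problem_sizes[expert_id * 3 + 0] = SWAP_AB ? n : m;
    problem_sizes[expert_id * 3 + 1] = SWAP_AB ? m : n;
    problem_sizes[expert_id * 3 + 2] = k;
  }
}

int main() {
  // Toy routing: 6 token slots over 2 experts, well under the threshold of
  // 64, so the caller above would take the SWAP_AB=true branch.
  const int num_experts = 2, topk_length = 6, n = 128, k = 64;
  int32_t h_ids[topk_length] = {0, 1, 0, 0, 1, 0};
  int32_t *d_ids, *d_sizes, *d_atomic;
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMalloc(&d_sizes, num_experts * 3 * sizeof(int32_t));
  cudaMalloc(&d_atomic, num_experts * sizeof(int32_t));
  cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);
  cudaMemset(d_atomic, 0, num_experts * sizeof(int32_t));
  sketch_problem_sizes<true><<<num_experts, kThreads>>>(
      d_ids, d_sizes, d_atomic, topk_length, n, k);
  int32_t h_sizes[num_experts * 3];
  cudaMemcpy(h_sizes, d_sizes, sizeof(h_sizes), cudaMemcpyDeviceToHost);
  for (int e = 0; e < num_experts; ++e) {
    // Expect {128, 4, 64} and {128, 2, 64}: n comes first since M/N swapped.
    printf("expert %d: {%d, %d, %d}\n", e, h_sizes[e * 3],
           h_sizes[e * 3 + 1], h_sizes[e * 3 + 2]);
  }
  cudaFree(d_ids);
  cudaFree(d_sizes);
  cudaFree(d_atomic);
  return 0;
}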