Skip to content

Commit 2de546c

Browse files
committed
rebase and fix
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent c403a57 commit 2de546c

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512;
99
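// Threshold must match the dispatch logic in run_cutlass_moe_mm_sm90().
// When topk_ids.numel() is greater than this value the caller launches
// compute_problem_sizes<false>; otherwise it launches
// compute_problem_sizes<true> (the SWAP_AB variant).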
1010
constexpr int SWAP_AB_THRESHOLD = 64;
1111

12+
template <bool SWAP_AB>
1213
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
1314
int32_t* problem_sizes1,
1415
int32_t* problem_sizes2,
@@ -117,11 +118,21 @@ void get_cutlass_moe_mm_data_caller(
117118
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
118119

119120
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
120-
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
121-
static_cast<const int32_t*>(topk_ids.data_ptr()),
122-
static_cast<int32_t*>(problem_sizes1.data_ptr()),
123-
static_cast<int32_t*>(problem_sizes2.data_ptr()),
124-
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
121+
if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
122+
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
123+
static_cast<const int32_t*>(topk_ids.data_ptr()),
124+
static_cast<int32_t*>(problem_sizes1.data_ptr()),
125+
static_cast<int32_t*>(problem_sizes2.data_ptr()),
126+
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
127+
k);
128+
} else {
129+
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
130+
static_cast<const int32_t*>(topk_ids.data_ptr()),
131+
static_cast<int32_t*>(problem_sizes1.data_ptr()),
132+
static_cast<int32_t*>(problem_sizes2.data_ptr()),
133+
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
134+
k);
135+
}
125136
if (blockscale_offsets.has_value()) {
126137
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
127138
static_cast<const int32_t*>(problem_sizes1.data_ptr()),

0 commit comments

Comments
 (0)