77
88constexpr uint64_t THREADS_PER_EXPERT = 512 ;
99
10- __global__ void compute_problem_sizes (const uint32_t * __restrict__ topk_ids,
10+ __global__ void compute_problem_sizes (const int32_t * __restrict__ topk_ids,
1111 int32_t * problem_sizes1,
1212 int32_t * problem_sizes2,
1313 int32_t * atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
6262 }
6363}
6464
65- __global__ void compute_arg_sorts (const uint32_t * __restrict__ topk_ids,
65+ __global__ void compute_arg_sorts (const int32_t * __restrict__ topk_ids,
6666 const int32_t * __restrict__ expert_offsets,
6767 int32_t * input_permutation,
6868 int32_t * output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
103103
104104 int num_threads = min (THREADS_PER_EXPERT, topk_ids.numel ());
105105 compute_problem_sizes<<<num_experts, num_threads, 0 , stream>>> (
106- static_cast <const uint32_t *>(topk_ids.data_ptr ()),
106+ static_cast <const int32_t *>(topk_ids.data_ptr ()),
107107 static_cast <int32_t *>(problem_sizes1.data_ptr ()),
108108 static_cast <int32_t *>(problem_sizes2.data_ptr ()),
109109 static_cast <int32_t *>(atomic_buffer.data_ptr ()), topk_ids.numel (), n, k);
@@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
120120 static_cast <int32_t *>(atomic_buffer.data_ptr ()), num_experts);
121121 }
122122 compute_arg_sorts<<<num_experts, num_threads, 0 , stream>>> (
123- static_cast <const uint32_t *>(topk_ids.data_ptr ()),
123+ static_cast <const int32_t *>(topk_ids.data_ptr ()),
124124 static_cast <const int32_t *>(expert_offsets.data_ptr ()),
125125 static_cast <int32_t *>(input_permutation.data_ptr ()),
126126 static_cast <int32_t *>(output_permutation.data_ptr ()),
0 commit comments