@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
198198}
199199
200200// taken from
201- // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
201+ // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
202202template <typename scalar_t >
203203__global__ void sgl_moe_align_block_size_kernel (
204204 scalar_t * __restrict__ topk_ids, int32_t * sorted_token_ids,
205205 int32_t * expert_ids, int32_t * total_tokens_post_pad, int32_t num_experts,
206206 int32_t block_size, size_t numel, int32_t * cumsum) {
207207 __shared__ int32_t shared_counts[32 ][8 ];
208- __shared__ int32_t local_offsets[256 ];
209208
210209 const int warp_id = threadIdx .x / 32 ;
211- const int lane_id = threadIdx .x % 32 ;
212210 const int experts_per_warp = 8 ;
213211 const int my_expert_start = warp_id * experts_per_warp;
214212
213+ // Initialize shared_counts for this warp's experts
215214 for (int i = 0 ; i < experts_per_warp; ++i) {
216215 if (my_expert_start + i < num_experts) {
217216 shared_counts[warp_id][i] = 0 ;
218217 }
219218 }
220219
220+ __syncthreads ();
221+
221222 const size_t tokens_per_thread = CEILDIV (numel, blockDim .x );
222223 const size_t start_idx = threadIdx .x * tokens_per_thread;
223224
@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
230231
231232 __syncthreads ();
232233
234+ // Single thread computes cumulative sum and total tokens
233235 if (threadIdx .x == 0 ) {
234236 cumsum[0 ] = 0 ;
235237 for (int i = 1 ; i <= num_experts; ++i) {
@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
246248
247249 __syncthreads ();
248250
251+ // Assign expert IDs to blocks
249252 if (threadIdx .x < num_experts) {
250253 for (int i = cumsum[threadIdx .x ]; i < cumsum[threadIdx .x + 1 ];
251254 i += block_size) {
252255 expert_ids[i / block_size] = threadIdx .x ;
253256 }
254- local_offsets[threadIdx .x ] = cumsum[threadIdx .x ];
255257 }
258+ }
256259
257- __syncthreads ();
258-
259- for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
// taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
//
// Scatters token indices into `sorted_token_ids`, grouped by expert.
// On entry, `cumsum_buffer[e]` holds the starting write offset of expert
// `e`'s segment; each placement atomically advances that offset, so every
// token claims a unique slot inside its expert's segment. Launch with any
// 1-D grid — the grid-stride loop covers all `numel` entries regardless of
// grid size.
template <typename scalar_t>
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
                                          int32_t* sorted_token_ids,
                                          int32_t* cumsum_buffer,
                                          size_t numel) {
  const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;

  for (size_t token = first; token < numel; token += step) {
    const int32_t expert = topk_ids[token];
    // Claim the next free slot in this expert's output segment.
    const int32_t slot = atomicAdd(&cumsum_buffer[expert], 1);
    sorted_token_ids[slot] = token;
  }
}
// Aligns and sorts MoE token assignments into block_size-padded per-expert
// segments (two-phase: a single-block align kernel computes padded cumulative
// offsets and the block->expert map, then a grid-stride sort kernel scatters
// token indices). Results are written into sorted_token_ids, experts_ids and
// num_tokens_post_pad; all kernels run on the current CUDA stream.
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size, torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The align kernel hard-codes a 32x8 shared-counts layout (256 experts),
  // i.e. the DeepSeek-V3 configuration.
  TORCH_CHECK(num_experts == 256,
              "sgl_moe_align_block_size kernel only supports deepseek v3.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        auto options_int =
            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
        // Must be zero-initialized (not torch::empty): the align kernel
        // accumulates prefix sums into it and the sort kernel then atomically
        // increments the per-expert offsets, so stale garbage would corrupt
        // the output placement.
        torch::Tensor cumsum_buffer =
            torch::zeros({num_experts + 1}, options_int);

        // Phase 1: one block computes per-expert counts, the block_size-padded
        // cumulative offsets, and the block -> expert mapping.
        auto align_kernel =
            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());

        // Phase 2: scatter token indices into their experts' segments.
        // Compute the grid size in 64-bit (numel() is int64_t) before clamping
        // to the 65535-block launch bound, and never launch fewer than one
        // block — a zero-sized grid is a CUDA launch error, and the kernel's
        // grid-stride loop is already a no-op when numel == 0.
        const int block_threads = 256;
        const int64_t needed_blocks =
            (topk_ids.numel() + block_threads - 1) / block_threads;
        const int64_t max_blocks = 65535;
        const int actual_blocks = static_cast<int>(
            std::min(std::max(needed_blocks, int64_t{1}), max_blocks));
        auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
      });
}
399421
0 commit comments