diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 8b6fe72ad743b..b1b91e997fa74 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel( } } +// temporarily adapted from +// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a +template +__global__ void sgl_moe_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* cumsum) { + __shared__ int32_t shared_counts[32][8]; + __shared__ int32_t local_offsets[256]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int experts_per_warp = 8; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < num_experts) { + shared_counts[warp_id][i] = 0; + } + } + + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + int expert_count = 0; + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + expert_count = shared_counts[warp_idx][expert_offset]; + + cumsum[i] = + cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + local_offsets[threadIdx.x] = cumsum[threadIdx.x]; + } + + __syncthreads(); + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + template __global__ void moe_sum_kernel( scalar_t* __restrict__ out, // [..., d] @@ -305,6 +371,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } } +void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + // torch::Tensor token_cnts_buffer = + // torch::empty({(num_experts + 1) * num_experts}, options_int); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); + + auto kernel = vllm::moe::sgl_moe_align_block_size_kernel; + kernel<<<1, 1024, 0, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel(), cumsum_buffer.data_ptr()); + }); +} + void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 596cc0aa6c855..66bb5f41b7f78 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -12,3 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); + +void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index f3a558c14ab93..8540633dcc8b0 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -22,6 +22,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + // temporarily adapted from + // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a + m.def( + "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); + #ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 85c1121ed6ff8..d1f1a2b2081fd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -923,6 +923,15 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, num_tokens_post_pad) +def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts, + block_size, sorted_token_ids, + experts_ids, num_tokens_post_pad) + + def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies: torch.Tensor, gating_output: float) -> None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index dbb6c2ce4649e..33a0706ca98ad 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -18,6 +18,9 @@ logger = init_logger(__name__) +enable_moe_align_block_size_triton = bool( + int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))) + @triton.jit def fused_moe_kernel_gptq_awq( @@ -404,6 +407,144 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, + numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts, ) + tokens_cnts = torch.zeros((num_experts + 1, num_experts), + dtype=torch.int32, + device=topk_ids.device) + cumsum = torch.zeros((num_experts + 1, ), + dtype=torch.int32, + device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1, )]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -456,8 +597,28 @@ def moe_align_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + if num_experts >= 224: + if enable_moe_align_block_size_triton: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad