diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index fda981f4c800..c4488c0c6ff3 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -9,8 +9,11 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - deep_gemm_moe_fp8, fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - if (N <= 512): + if N <= 512: pytest.skip("Skipping N <= 512 until performance issues solved.") vllm_config = VllmConfig() diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 1652c72d86fe..3cfed6ae8538 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -4,8 +4,8 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - fused_experts, +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph( c_strides2, a1_scale=a_scale1) - print(triton_output) - print(cutlass_output) - print("*") + #print(triton_output) + #print(cutlass_output) + #print("*") torch.testing.assert_close(triton_output, cutlass_output, @@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph( graph.replay() torch.cuda.synchronize() - print(triton_output) - print(cutlass_output) - print("*") + #print(triton_output) + #print(cutlass_output) + #print("*") torch.testing.assert_close(triton_output, cutlass_output, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e096d14fc6f9..9829ccdb384f 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -35,9 +35,11 @@ def get_config() -> Optional[Dict[str, Any]]: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( - cutlass_moe_fp8, fused_experts, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ "fused_moe", diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py new file mode 100644 index 000000000000..a17afd1b357e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE 
kernel.""" +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + - out_dtype (torch.Tensor): The output tensor type. + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. 
+ """ + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_q.dtype == torch.float8_e4m3fn + assert w2_q.dtype == torch.float8_e4m3fn + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1_q.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2_q.shape[2], "W2 scale shape mismatch" + assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert ab_strides1.shape[0] == w1_q.shape[ + 0], "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1_q.shape[ + 0], "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2_q.shape[ + 0], "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2_q.shape[ + 0], "C Strides 2 expert number mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + topk = topk_ids.size(1) + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) + device = a_q.device + + expert_offsets = torch.empty((num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, num_experts, n, + k) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + + c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) + c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + + ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, + expert_offsets[:-1], problem_sizes1, ab_strides1, + ab_strides1, c_strides1) + + intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) + torch.ops._C.silu_and_mul(intermediate, c1) + + intemediate_q, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) + + ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, ab_strides2, + ab_strides2, c_strides2) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py new file mode 100644 index 000000000000..353c8cc9d59f --- 
/dev/null +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import Optional, Tuple + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, + _fp8_quantize, + _resize_cache) +from vllm.utils import round_up + +logger = init_logger(__name__) + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +def _valid_deep_gemm(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + expert_map: Optional[torch.Tensor] = None) -> bool: + """ + Check if the given problem size is supported by the DeepGemm grouped + gemm kernel. All of M, N, K and the quantization block_shape must be + aligned by `dg.get_m_alignment_for_contiguous_layout()`. + """ + if not has_deep_gemm: + return False + + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + + # Expert maps not supported yet. + if expert_map is not None: + return False + + align = dg.get_m_alignment_for_contiguous_layout() + M = hidden_states.shape[0] + _, K, N = w2.shape + + # For now, disable DeepGemm for small N until better permute/unpermute + # ops are available. + if N <= 512: + return False + + if align > M or N % align != 0 or K % align != 0: + return False + + return (hidden_states.is_contiguous() and w1.is_contiguous() + and w2.is_contiguous()) + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.shape[1] + + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.shape + K = curr_hidden.shape[1] + curr_hidden = curr_hidden[inv_perm, ...] 
+ curr_hidden = curr_hidden.view(-1, topk, K) + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) + + +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + # Lazy import to avoid CUDA initialization problems. 
+ import deep_gemm as dg + + assert expert_map is None, "Expert maps not supported yet" + + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert a1_scale is None or a1_scale.dim( + ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ + 0] == hidden_states.shape[0], "Input scale shape mismatch" + assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + + assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + block_m = dg.get_m_alignment_for_contiguous_layout() + block_shape = [block_m, block_m] + + assert w1_scale is not None + assert w2_scale is not None + + # We attempt to transpose and align offline in Fp8MoEMethod, in which + # case these calls will be nops. Otherwise, they'll be performed every + # time the layer is executed. + w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() + w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + + M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.empty(M_sum * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + + workspace1 = workspace13[:M_sum * N].view(M_sum, N) + workspace2 = torch.empty((M_sum, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + workspace3 = workspace13[:M_sum * K].view(M_sum, K) + + for chunk in range(num_chunks): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + a1q_scale: Optional[torch.Tensor] = None + + qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, + a1_scale, block_shape) + + (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, + curr_topk_ids, global_num_experts, + expert_map, block_m) + + # Adjust the intermediate cache size and config for the last chunk. 
+ # Note that in most cases we only have one chunk so the cache size + # and config are already set correctly and do not need to be adjusted. + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + curr_M = sorted_token_ids.numel() + workspace1 = _resize_cache(workspace1, (curr_M, N)) + workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) + workspace3 = _resize_cache(workspace3, (curr_M, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, + expert_ids) + + if activation == "silu": + torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, + block_shape) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) + + _moe_unpermute_and_reduce( + out_hidden_states[begin_chunk_idx:end_chunk_idx], + workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) + + return out_hidden_states diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 977447e03995..aa0bd553fc32 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" import functools -import importlib.util import json import os -from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -14,10 +12,13 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, round_up +from vllm.utils import direct_register_custom_op from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled, rocm_aiter_fused_experts, @@ -25,8 +26,6 @@ logger = init_logger(__name__) -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -443,300 +442,13 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def ceil_div(a, b): - return (a + b - 1) // b - - -@triton.jit -def moe_align_block_size_stage1( - topk_ids_ptr, - tokens_cnts_ptr, - num_experts: tl.constexpr, - numel: tl.constexpr, - tokens_per_thread: tl.constexpr, -): - pid = tl.program_id(0) - - start_idx = pid * tokens_per_thread - - off_c = (pid + 1) * num_experts - - for i in range(tokens_per_thread): - if start_idx + i < numel: - idx = tl.load(topk_ids_ptr + start_idx + i) - token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) - tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) - - -@triton.jit -def moe_align_block_size_stage2( - tokens_cnts_ptr, - num_experts: tl.constexpr, -): - pid = tl.program_id(0) - - last_cnt = 0 - for i in range(1, num_experts + 1): - token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) - last_cnt = last_cnt 
+ token_cnt - tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) - - -@triton.jit -def moe_align_block_size_stage3( - total_tokens_post_pad_ptr, - tokens_cnts_ptr, - cumsum_ptr, - num_experts: tl.constexpr, - block_size: tl.constexpr, -): - last_cumsum = 0 - off_cnt = num_experts * num_experts - for i in range(1, num_experts + 1): - token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) - last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size - tl.store(cumsum_ptr + i, last_cumsum) - tl.store(total_tokens_post_pad_ptr, last_cumsum) - - -@triton.jit -def moe_align_block_size_stage4( - topk_ids_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - tokens_cnts_ptr, - cumsum_ptr, - num_experts: tl.constexpr, - block_size: tl.constexpr, - numel: tl.constexpr, - tokens_per_thread: tl.constexpr, -): - pid = tl.program_id(0) - start_idx = tl.load(cumsum_ptr + pid) - end_idx = tl.load(cumsum_ptr + pid + 1) - - for i in range(start_idx, end_idx, block_size): - tl.store(expert_ids_ptr + i // block_size, pid) - - start_idx = pid * tokens_per_thread - off_t = pid * num_experts - - for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, - numel)): - expert_id = tl.load(topk_ids_ptr + i) - token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) - rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) - tl.store(sorted_token_ids_ptr + rank_post_pad, i) - tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) - - -# Triton implementation based on: -# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 -def moe_align_block_size_triton( - topk_ids: torch.Tensor, - num_experts: int, - block_size: int, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, -) -> None: - numel = topk_ids.numel() - grid = (num_experts, ) - tokens_cnts = torch.zeros((num_experts + 1, num_experts), - dtype=torch.int32, - device=topk_ids.device) - cumsum = torch.zeros((num_experts + 1, ), - dtype=torch.int32, - device=topk_ids.device) - tokens_per_thread = ceil_div(numel, num_experts) - - moe_align_block_size_stage1[grid]( - topk_ids, - tokens_cnts, - num_experts, - numel, - tokens_per_thread, - ) - moe_align_block_size_stage2[grid]( - tokens_cnts, - num_experts, - ) - moe_align_block_size_stage3[(1, )]( - num_tokens_post_pad, - tokens_cnts, - cumsum, - num_experts, - block_size, - ) - moe_align_block_size_stage4[grid]( - topk_ids, - sorted_token_ids, - expert_ids, - tokens_cnts, - cumsum, - num_experts, - block_size, - numel, - tokens_per_thread, - ) - - -def moe_align_block_size( - topk_ids: torch.Tensor, - block_size: int, - num_experts: int, - expert_map: Optional[torch.Tensor] = None, - pad_sorted_ids: bool = False -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Aligns the token distribution across experts to be compatible with block - size for matrix multiplication. - - Parameters: - - topk_ids: A tensor of shape [total_tokens, top_k] representing the - top-k expert indices for each token. - - block_size: The block size used in block matrix multiplication. - - num_experts: The total number of experts. - - expert_map: A tensor of shape [num_experts] that maps the expert index - from the global space to the local index space of the current - expert parallel shard. If the expert is not in the current expert - parallel shard, the mapping is set to -1. 
- - pad_sorted_ids: A flag indicating whether the sorted_token_ids length - should be padded to a multiple of block_size, - - Returns: - - sorted_token_ids: A tensor containing the sorted token indices according - to their allocated expert. - - expert_ids: A tensor indicating the assigned expert index for each block. - - num_tokens_post_padded: The total number of tokens after padding, - ensuring divisibility by block_size. - - This function pads the number of tokens that each expert needs to process - so that it is divisible by block_size. - Padding ensures that during block matrix multiplication, the dimensions - align correctly. - - Example: - Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], - block_size = 4, and num_experts = 4: - - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, - with each expert needing to process 3 tokens. - - As block_size is 4, we pad 1 token for each expert. - - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - - Then append padding tokens [12, 12, 12, 12] for each block. - - After sorting by expert index, we obtain token_ids - [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. - Tokens 12 are non-existent (padding) and are ignored in - the subsequent matrix multiplication. - - The padding ensures that the total number of tokens is now divisible - by block_size for proper block matrix operations. - """ - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - if pad_sorted_ids: - max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) - sorted_ids.fill_(topk_ids.numel()) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - # Expert ids must be zeroed out to prevent index out of bounds error while - # mapping global expert ids to local expert ids in expert parallelism. - expert_ids = torch.zeros((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - if num_experts >= 224: - if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256: - moe_align_block_size_triton( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ) - else: - # Currently requires num_experts=256 - ops.sgl_moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ) - else: - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) - if expert_map is not None: - expert_ids = expert_map[expert_ids] - - return sorted_ids, expert_ids, num_tokens_post_pad - - -def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, - expert_map: Optional[torch.Tensor]) -> bool: - """ - Check if the given problem size is supported by the DeepGemm grouped - gemm kernel. All of M, N, K and the quantization block_shape must be - aligned by `dg.get_m_alignment_for_contiguous_layout()`. - """ - if not has_deep_gemm: - return False - - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - # Expert maps not supported yet. - if expert_map is not None: - return False - - align = dg.get_m_alignment_for_contiguous_layout() - M = hidden_states.shape[0] - _, K, N = w2.shape - - # For now, disable DeepGemm for small N until better permute/unpermute - # ops are available. 
- if N <= 512: - return False - - if align > M or N % align != 0 or K % align != 0: - return False - - return (hidden_states.is_contiguous() and w1.is_contiguous() - and w2.is_contiguous()) - - -def _fp8_quantize( - A: torch.Tensor, - A_scale: Optional[torch.Tensor], - block_shape: Optional[List[int]], -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform fp8 quantization on the inputs. If a block_shape - is provided, the output will be blocked. - """ - if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) - else: - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - return A, A_scale - - def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], B_zp: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_weights: Optional[torch.Tensor], sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -748,7 +460,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, use_int8_w8a16: bool, use_int4_w4a16: bool, block_shape: Optional[List[int]] = None) -> None: - assert topk_weights.stride(1) == 1 + assert topk_weights is not None or not mul_routed_weight + assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8: @@ -765,6 +478,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert A_scale is None assert B_scale is None + M = A.shape[0] + num_tokens = M * top_k + EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -782,7 +498,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert B_zp is None or B_zp.ndim == 3 use_moe_wna16_cuda = should_moe_wna16_use_cuda( - num_valid_tokens=topk_ids.numel(), + num_valid_tokens=num_tokens, group_size=block_shape[1], num_experts=B.shape[0], bit=4 if use_int4_w4a16 else 8) @@ -790,12 +506,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor, config.update( get_moe_wna16_block_config(config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, - num_valid_tokens=topk_ids.numel(), + num_valid_tokens=num_tokens, size_k=A.shape[1], size_n=B.shape[1], num_experts=B.shape[1], group_size=block_shape[1], - real_top_k=topk_ids.shape[1], + real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"])) if use_moe_wna16_cuda: @@ -821,7 +537,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.shape[1], A.shape[1], EM, - topk_ids.numel(), + num_tokens, A.stride(0), A.stride(1), B.stride(0), @@ -864,7 +580,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.shape[1], B.shape[2], EM, - topk_ids.numel(), + num_tokens, A.stride(0), A.stride(1), B.stride(0), @@ -1389,6 +1105,7 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + inplace=inplace, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1419,85 +1136,6 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: - """ - A permutation routine that works on fp8 types. - """ - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) - else: - return m[idx, ...] 
- - -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - top_k_num: int, - block_m: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. - """ - tokens_in_chunk, _ = curr_hidden_states.shape - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) - - inv_perm: Optional[torch.Tensor] = None - - num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] - - # Permute according to sorted token ids. - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) - - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) - - -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk: int, - K: int, - topk_weight: torch.Tensor, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M = topk_weight.shape[0] - curr_hidden = curr_hidden[inv_perm, ...] - curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) - - -def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: - """ - Shrink the given tensor and apply the given view to it. This is - used to resize the intermediate fused_moe caches. 
- """ - assert prod(v) <= x.numel() - return x.flatten()[:prod(v)].view(*v) - - def fused_experts_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1629,7 +1267,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale, w1_zp, curr_topk_weights, - curr_topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, @@ -1660,28 +1297,34 @@ def fused_experts_impl(hidden_states: torch.Tensor, qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - block_shape=block_shape) + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, #True, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape) + + if True: + intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) + intermediate_cache3.mul_( + curr_topk_weights.view(tokens_in_chunk, -1, 1)) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states @@ -1790,327 +1433,3 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) - - -def deep_gemm_moe_fp8( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with DeepGemm - grouped gemm. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1 (torch.Tensor): The first set of fp8 quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2 (torch.Tensor): The second set of fp8 quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. 
- - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - Returns: - - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - """ - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - assert expert_map is None, "Expert maps not supported yet" - - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == hidden_states.shape[0], "Input scale shape mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] - if global_num_experts == -1: - global_num_experts = E - top_k_num = topk_ids.shape[1] - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - - assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - block_m = dg.get_m_alignment_for_contiguous_layout() - block_shape = [block_m, block_m] - - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. 
- w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - cache13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N) - intermediate_cache2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K) - - for chunk in range(num_chunks): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - a1q_scale: Optional[torch.Tensor] = None - - qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, - a1_scale, block_shape) - - (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, - curr_topk_ids, global_num_experts, - expert_map, top_k_num, block_m) - - # Adjust the intermediate cache size and config for the last chunk. - # Note that in most cases we only have one chunk so the cache size - # and config are already set correctly and do not need to be adjusted. - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - curr_M = sorted_token_ids.numel() - intermediate_cache1 = _resize_cache(intermediate_cache1, - (curr_M, N)) - intermediate_cache2 = _resize_cache(intermediate_cache2, - (curr_M, N // 2)) - intermediate_cache3 = _resize_cache(intermediate_cache3, - (curr_M, K)) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qcurr_hidden_states, a1q_scale), (w1, w1_scale), - intermediate_cache1, expert_ids) - - if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - a2q_scale: Optional[torch.Tensor] = None - - qintermediate_cache2, a2q_scale = _fp8_quantize( - intermediate_cache2, a2_scale, block_shape) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qintermediate_cache2, a2q_scale), (w2, w2_scale), - intermediate_cache3, expert_ids) - - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, - top_k_num, K, curr_topk_weights) - - return out_hidden_states - - -#TODO make the grouped gemm kernel consistent with scaled gemm kernel -def cutlass_moe_fp8( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of 
Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with CUTLASS - grouped gemm. - - Parameters: - - a (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - ab_strides1 (torch.Tensor): The input and weights strides of the first - grouped gemm. - - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - - ab_strides2 (torch.Tensor): The input and weights strides of the second - grouped gemm. - - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. - - Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. - """ - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - topk = topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = 
a_q.device - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) - - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) - - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, - expert_offsets[:-1], problem_sizes1, ab_strides1, - ab_strides1, c_strides1) - - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) - torch.ops._C.silu_and_mul(intermediate, c1) - - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, ab_strides2, - ab_strides2, c_strides2) - - return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py new file mode 100644 index 000000000000..07d51acf9867 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.utils import round_up + + +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + 
block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, + numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts, ) + tokens_cnts = torch.zeros((num_experts + 1, num_experts), + dtype=torch.int32, + device=topk_ids.device) + cumsum = torch.zeros((num_experts + 1, ), + dtype=torch.int32, + device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1, )]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + +def moe_align_block_size( + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: Optional[torch.Tensor] = None, + pad_sorted_ids: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. 
+ + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. + expert_ids = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + if num_experts >= 224: + if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + # Currently requires num_experts=256 + ops.sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py new file mode 100644 index 000000000000..db31422f7275 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +from math import prod +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import cdiv + + +def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: + """ + Shrink the given tensor and apply the given view to it. This is + used to resize the intermediate fused_moe caches. + """ + assert prod(v) <= x.numel() + return x.flatten()[:prod(v)].view(*v) + + +def _fp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + block_shape: Optional[List[int]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform fp8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. 
+    """
+    if block_shape is None:
+        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+    else:
+        assert len(block_shape) == 2
+        _, block_k = block_shape[0], block_shape[1]
+        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+    return A, A_scale
+
+
+def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+    """
+    A permutation routine that works on fp8 types.
+    """
+    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
+        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
+    else:
+        return m[idx, ...]
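
For reference, a minimal usage sketch of the relocated moe_align_block_size helper follows. The token counts, expert ids and block size below are made up for illustration, and the call assumes a CUDA build of vLLM (the alignment kernels are custom CUDA/Triton ops); only the new import path and the documented return contract come from the diff above.

# Illustrative only: arbitrary shapes, requires a CUDA device.
import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)

# 4 tokens, top_k = 3, experts 0..3 -> each expert receives 3 token slots.
topk_ids = torch.tensor([[0, 1, 2], [3, 0, 1], [2, 3, 0], [1, 2, 3]],
                        dtype=torch.int32, device="cuda")
block_size = 4
num_experts = 4

sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts)

# Per the docstring, each expert's slice is padded up to a multiple of
# block_size, so the total post-padding count is divisible by block_size
# (here 4 experts * 4 slots = 16). Padding slots hold the sentinel value
# topk_ids.numel() (12) and are ignored by the grouped matmuls.
assert num_post_pad.item() % block_size == 0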