diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 2a2f32b646bd0..74e5acbccea63 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -8,4 +8,11 @@ vllm-nccl-cu12>=2.18,<2.19  # for downloading nccl library
 torch == 2.2.1
 xformers == 0.0.25  # Requires PyTorch 2.2.1
-cupy-cuda12x
\ No newline at end of file
+# Dependencies for pycublas moe-group-gemm
+gitpython
+pytest
+loguru
+# In case of an invalid URL, please install from the bundled file:
+# pip install gitpython pytest loguru && pip install vllm/model_executor/layers/fused_moe/pycublas.zip
+# or
+# pip install gitpython pytest loguru && pip install git+https://github.com/wenxcs/pycublas.git@moe-group-gemm
diff --git a/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py b/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py
index 82f5337d8d1c9..8852f98f7e1e4 100644
--- a/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/ampere_fp8_fused_moe.py
@@ -1,4 +1,5 @@
 """Fused MoE kernel."""
+
 import functools
 import json
 import os
@@ -8,227 +9,75 @@
 import triton
 import triton.language as tl
 
+from vllm.model_executor.layers.fused_moe import gather_scatter_kernel
+
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
-
-logger = init_logger(__name__)
-
-
-@triton.jit
-def fused_moe_kernel(
-    # Pointers to matrices
-    a_ptr,
-    b_ptr,
-    c_ptr,
-    a_scale_ptr,
-    b_scale_ptr,
-    topk_weights_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    num_tokens_post_padded_ptr,
-    # Matrix dimensions
-    N,
-    K,
-    EM,
-    num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when
-    # moving by 1 element in a particular dimension. E.g. `stride_am` is
-    # how much to increase `a_ptr` by to get the element one row down
-    # (A has M rows).
-    stride_am,
-    stride_ak,
-    stride_be,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    # Meta-parameters
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    MUL_ROUTED_WEIGHT: tl.constexpr,
-    top_k: tl.constexpr,
-    compute_type: tl.constexpr,
-    use_fp8: tl.constexpr,
-):
-    """
-    Implements the fused computation for a Mixture of Experts (MOE) using
-    token and expert matrices.
-
-    Key Parameters:
-    - A: The input tensor representing tokens with shape (*, K), where '*' can
-        be any shape representing batches and K is the feature dimension of
-        each token.
-    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
-        the number of experts, K is the input feature dimension, and N is
-        the output feature dimension.
-    - C: The output cache tensor with shape (M, topk, N), where M is the
-        total number of tokens post padding, topk is the number of times
-        each token is repeated, and N is the output feature dimension.
-    - sorted_token_ids: A tensor containing the sorted indices of tokens,
-        repeated topk times and arranged by the expert index they are
-        assigned to.
-    - expert_ids: A tensor containing the indices of the expert for each
-        block. It determines which expert matrix from B should be used for
-        each block in A.
-    This kernel performs the multiplication of a token by its corresponding
-    expert matrix as determined by `expert_ids`. The sorting of
-    `sorted_token_ids` by expert index and padding ensures divisibility by
-    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
-    multiplication across different blocks processed by the same expert.
- """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) - - off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) - - if use_fp8: - b_scale = tl.load(b_scale_ptr + off_experts) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) +from vllm.utils import print_warning_once - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the - # K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ).to(tl.float16) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) +import pycublas.trtllm_moe_grouped_gemm as moe_kernel - # todo(wenxh): there is a bug in triton 2.2/2.3 that only "=l" works, "=r" - # will result error in llvm check(low level bug). - b = tl.inline_asm_elementwise( - asm = "{ \n" - ".reg .b32 a<2>, b<2>; \n" # if input = 0xf1f2f3f4 - ".reg .b32 r0, r1; \n" - "prmt.b32 a0, 0, $1, 0x5040; \n" # a0 = 0xf300f400 - "prmt.b32 a1, 0, $1, 0x7060; \n" # a1 = 0xf100f200 - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" # b0 = a0 & 0x7fff7fff - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" # (strip sign) - "shr.b32 b0, b0, 1; \n" # b0 >>= 1 - "shr.b32 b1, b1, 1; \n" # shift into fp16 position - "add.u32 b0, b0, 0x20002000; \n" # b0.exp += 2**4-2**3 - # exponent compensate = 8 - "add.u32 b1, b1, 0x20002000; \n" # b1 += 8<<10 | 8<<10<<16 - "lop3.b32 r0, b0, 0x80008000, a0, 0xf8; \n" # out0 = b0|(0x80008000&a0) - "lop3.b32 r1, b1, 0x80008000, a1, 0xf8; \n" # (restore sign) - "mov.b64 $0, {r0, r1}; \n" - "} \n", - constraints="=l,r", - args = [b], - dtype = tl.float16, - is_pure = True, - pack = 4 - ) * b_scale.to(tl.float16) - - # We accumulate along the K dimension. 
-        if use_fp8:
-            accumulator = tl.dot(a, b, acc=accumulator)
-        else:
-            accumulator += tl.dot(a, b)
-        # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+logger = init_logger(__name__)
 
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
+# Ampere FP8 kernel
+from torch.utils.cpp_extension import load
 
-    if use_fp8:
-        accumulator = accumulator.to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
-    # -----------------------------------------------------------
-    # Write back the block of the output
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+print_warning_once("Loading ampere_fp8 kernel")
+ampere_fp8 = load(
+    name="ampere_fp8",
+    sources=[
+        os.path.join(
+            os.path.dirname(__file__), "csrc", "moe_align_block_size_kernels.cu"
+        )
+    ],
+)
+
+moe_gg_kernel_config = {
+    1: (13, 21, 0.4587008017301559),
+    2: (5, 11, 0.4829593604803085),
+    3: (11, 4, 0.55322624117136),
+    4: (5, 5, 0.6300467216968536),
+    5: (5, 9, 0.6892339181900025),
+    6: (5, 5, 0.7366860777139663),
+    7: (17, 9, 0.7817830407619476),
+    8: (5, 8, 0.8124313586950302),
+    16: (5, 5, 1.0158489656448364),
+    32: (4, 17, 1.0969907104969026),
+    48: (5, 4, 1.1068108654022217),
+    64: (17, 5, 1.1107225465774535),
+    80: (4, 5, 1.1139481484889984),
+    96: (16, 16, 1.1225907170772553),
+    112: (16, 16, 1.1334041678905487),
+    128: (17, 17, 1.137500158548355),
+    144: (16, 17, 1.144709119796753),
+    160: (16, 17, 1.1540889596939088),
+    176: (16, 16, 1.1627110350131988),
+    192: (17, 16, 1.1790643167495727),
+    208: (22, 16, 1.2127846336364747),
+    224: (23, 17, 1.2236697602272033),
+    240: (22, 22, 1.2352307152748108),
+    256: (23, 22, 1.2356915152072907),
+    512: (23, 22, 1.6425676786899566),
+    768: (27, 27, 1.7934028828144073),
+    1024: (27, 23, 2.4730009508132933),
+    1280: (22, 22, 3.02405633687973),
+    1536: (27, 22, 3.2711680245399477),
+    1792: (27, 26, 3.344619517326355),
+    2048: (27, 26, 4.023920638561249),
+    2304: (26, 22, 4.71138304233551),
+    2560: (27, 27, 4.861614079475403),
+    2816: (27, 27, 4.988712968826294),
+    3072: (26, 27, 5.624104981422424),
+    3328: (27, 26, 6.2363647985458375),
+    3584: (26, 26, 6.384680962562561),
+    3840: (26, 27, 6.581227521896363),
+    4096: (26, 27, 7.1324774312973025),
+}
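+
+# The table above looks like the output of offline tuning: it maps a token
+# count M to (cfg_id_0, cfg_id_1, and what appears to be a measured latency
+# in ms for that batch size; the third field is never read). A sketch of the
+# lookup rule that fused_moe() applies below (the helper name is
+# illustrative only, it does not exist elsewhere in this patch):
+#
+#     def nearest_gg_config(m: int):
+#         key = min(moe_gg_kernel_config.keys(), key=lambda x: abs(x - m))
+#         cfg_id_0, cfg_id_1, _ = moe_gg_kernel_config[key]
+#         return cfg_id_0, cfg_id_1
+#
+#     nearest_gg_config(100)  # -> (16, 16), via the nearest key M=96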
 
 
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Aligns the token distribution across experts to be compatible with block
-    size for matrix multiplication.
-
-    Parameters:
-    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
-        top-k expert indices for each token.
-    - block_size: The block size used in block matrix multiplication.
-    - num_experts: The total number of experts.
-
-    Returns:
-    - sorted_token_ids: A tensor containing the sorted token indices according
-        to their allocated expert.
-    - expert_ids: A tensor indicating the assigned expert index for each block.
-    - num_tokens_post_padded: The total number of tokens after padding,
-        ensuring divisibility by block_size.
-
-    This function pads the number of tokens that each expert needs to process
-    so that it is divisible by block_size.
-    Padding ensures that during block matrix multiplication, the dimensions
-    align correctly.
-
-    Example:
-    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
-    block_size = 4, and num_experts = 4:
-    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
-        with each expert needing to process 3 tokens.
-    - As block_size is 4, we pad 1 token for each expert.
-    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
-    - Then append padding tokens [12, 12, 12, 12] for each block.
-    - After sorting by expert index, we obtain token_ids
-        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
-        Tokens 12 are non-existent (padding) and are ignored in
-        the subsequent matrix multiplication.
-    - The padding ensures that the total number of tokens is now divisible
-        by block_size for proper block matrix operations.
-    """
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     sorted_ids = torch.empty(
         (topk_ids.numel() + num_experts * (block_size - 1),),
         dtype=torch.int32,
@@ -239,114 +88,34 @@ def moe_align_block_size(
     )
     sorted_ids.fill_(topk_ids.numel())
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    expert_off = torch.empty(
+        (num_experts + 1), dtype=torch.int32, device=topk_ids.device
     )
-    return sorted_ids, expert_ids, num_tokens_post_pad
-
-
-def invoke_fused_moe_kernel(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    C: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
-    mul_routed_weight: bool,
-    top_k: int,
-    config: Dict[str, Any],
-    compute_type: tl.dtype,
-    use_fp8: bool,
-) -> None:
-    assert topk_weights.stride(1) == 1
-    assert sorted_token_ids.stride(0) == 1
-
-    if not use_fp8:
-        assert A_scale is None
-        assert B_scale is None
-    else:
-        # A, A_scale = ops.scaled_fp8_quant(A, A_scale)
-        assert B_scale is not None
-
-    # A = triton.reinterpret(A, dtype=tl.uint8) if use_fp8 else A
-    B = triton.reinterpret(B, dtype=tl.uint8) if use_fp8 else B
-
-    grid = lambda META: (
-        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
-        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    expert_length = torch.empty(
+        (num_experts + 1), dtype=torch.int32, device=topk_ids.device
     )
-    fused_moe_kernel[grid](
-        A,
-        B,
-        C,
-        A_scale,
-        B_scale,
-        topk_weights,
-        sorted_token_ids,
+    ampere_fp8.ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
         expert_ids,
-        num_tokens_post_padded,
-        B.shape[1],
-        B.shape[2],
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        A.stride(0),
-        A.stride(1),
-        B.stride(0),
-        B.stride(2),
-        B.stride(1),
-        C.stride(1),
-        C.stride(2),
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-        **config,
+        num_tokens_post_pad,
+        expert_off,
+        expert_length,
     )
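+    # Besides the usual outputs, the extended CUDA kernel also fills
+    # per-expert ranges: expert_off[e] is the row at which expert e's
+    # block-padded segment starts, and expert_length[e] is expert e's
+    # unpadded token count. E.g. with block_size=4 and two experts owning
+    # 3 and 5 tokens, expert_off = [0, 4, 12] and expert_length = [3, 5].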
-
-
-def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
-    device_name = torch.cuda.get_device_name().replace(" ", "_")
-    dtype_selector = "" if not dtype else f",dtype={dtype}"
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
-
-
-@functools.lru_cache
-def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
-    """
-    Return optimized configurations for the fused MoE kernel.
-
-    The return value will be a dictionary that maps an irregular grid of
-    batch sizes to configurations of the fused_moe kernel. To evaluate the
-    kernel on a given batch size bs, the closest batch size in the grid should
-    be picked and the associated configuration chosen to invoke the kernel.
-    """
-
-    # First look up if an optimized configuration is available in the configs
-    # directory
-    # json_file_name = get_config_file_name(E, N, dtype)
-    json_file_name = "E=16,N=6400,device_name=NVIDIA_A100_80GB_PCIe.json"
-
-    config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    return (
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        expert_off.to(torch.int64),
+        expert_length,
     )
-    if os.path.exists(config_file_path):
-        with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.", config_file_path)
-            # If a configuration has been found, return it
-            return {int(key): val for key, val in json.load(f).items()}
-
-    # If no optimized configuration is available, we will use the default
-    # configuration
-    return None
 
 
 def fused_moe(
-    hidden_states: torch.Tensor,
+    activation: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     gating_output: torch.Tensor,
@@ -360,160 +129,113 @@ def fused_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     routing_func: Callable = torch.topk,
+    cfg_id_0=-1,
+    cfg_id_1=-1,
 ) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
+    hidden_states_dtype = activation.dtype
+    hidden_states = activation.to(torch.float16)
 
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    block_m = 16
+    block_k = 128
 
-    if routing_func != torch.topk:
-        topk_weights, topk_ids = routing_func(gating_output, topk)
-    elif is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
-        topk_weights, topk_ids = routing_func(routing_weights, topk)
-    else:
-        import vllm._moe_C as moe_kernels
+    if cfg_id_0 < 1 or cfg_id_1 < 1:
+        cfg_id_0, cfg_id_1, _ = moe_gg_kernel_config[
+            min(moe_gg_kernel_config.keys(), key=lambda x: abs(x - M))
+        ]
 
-        topk_weights = torch.empty(
-            M, topk, dtype=torch.float32, device=hidden_states.device
-        )
-        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-        token_expert_indicies = torch.empty(
-            M, topk, dtype=torch.int32, device=hidden_states.device
-        )
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights, topk_ids = routing_func(gating_output, topk)
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 32,
-                "GROUP_SIZE_M": 8,
-            }
+    sorted_token_ids, expert_ids, num_tokens_post_padded, expert_off, expert_length = (
+        moe_align_block_size(topk_ids, block_m, E)
+    )
 
-    if M <= E:
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
+    intermediate_cache3 = torch.empty(
+        (M, topk, K),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
+    gathered_cache = torch.empty(
+        (sorted_token_ids.size(0), K),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
+
+    gathered_cache_1 = torch.empty(
+        (sorted_token_ids.size(0), N),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
+
+    gathered_cache_2 = torch.empty(
+        (sorted_token_ids.size(0), N // 2),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
+    gathered_cache_3 = torch.empty(
+        (sorted_token_ids.size(0), K),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
    )
 
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
-    invoke_fused_moe_kernel(
+    # hidden states -> sorted hidden states
+    gather_scatter_kernel.invoke_moe_gather(
         hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
+        gathered_cache,
         sorted_token_ids,
-        expert_ids,
         num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
+        topk_ids,
+        block_m,
+        block_k,
+        topk,
+        4,
     )
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    total_rows_before_expert = expert_off[1:]
 
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
+    moe_kernel.grouped_gemm(
+        gathered_cache,
+        w1,
+        w1_scale,
+        total_rows_before_expert,
+        gathered_cache_1,
+        5,
+        cfg_id_0,
+    )
+
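+    # At this point the tokens sit in contiguous, block-padded per-expert
+    # rows, and the first grouped GEMM has run one mixed-precision GEMM
+    # (fp16 activations x fp8 weights, dequantized via w1_scale) per expert;
+    # total_rows_before_expert = expert_off[1:] tells the kernel where each
+    # expert's rows end. What follows: the SiLU-and-mul gate, the second
+    # grouped GEMM, and a scatter that restores token order and applies the
+    # routing weights.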
+    ops.silu_and_mul(gathered_cache_2, gathered_cache_1.view(-1, N))
+
+    moe_kernel.grouped_gemm(
+        gathered_cache_2.view(torch.float16),
+        w2.view(torch.int8),
+        w2_scale.view(hidden_states.dtype),
+        total_rows_before_expert,
+        gathered_cache_3,
+        5,
+        cfg_id_1,
+    )
+
+    gather_scatter_kernel.invoke_moe_scatter(
+        gathered_cache_3,
+        intermediate_cache3.view(-1, K),
         sorted_token_ids,
-        expert_ids,
         num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
+        topk_ids,
+        block_m,
+        block_k,
+        topk,
+        4,
+        topk_weights=topk_weights,
     )
 
+    intermediate_cache3 = intermediate_cache3[:M, :, :].to(hidden_states_dtype)
+
     if inplace:
         return torch.sum(
             intermediate_cache3.view(*intermediate_cache3.shape),
             dim=1,
-            out=hidden_states,
+            out=activation,
         )
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
diff --git a/vllm/model_executor/layers/fused_moe/csrc/cuda_compat.h b/vllm/model_executor/layers/fused_moe/csrc/cuda_compat.h
new file mode 100644
index 0000000000000..c711d8d1b24b9
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/csrc/cuda_compat.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define WARP_SIZE warpSize
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif
+
diff --git a/vllm/model_executor/layers/fused_moe/csrc/dispatch_utils.h b/vllm/model_executor/layers/fused_moe/csrc/dispatch_utils.h
new file mode 100644
index 0000000000000..91abd9e85b4bb
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/csrc/dispatch_utils.h
@@ -0,0 +1,37 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#pragma once
+
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
+  AT_DISPATCH_SWITCH(                                       \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)    \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
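+// Usage sketch for the integral dispatch defined below (illustrative; this
+// mirrors how moe_align_block_size_kernels.cu invokes it): the switch binds
+// `scalar_t` to each of Byte/Char/Short/Int/Long and runs the lambda, e.g.
+//
+//   VLLM_DISPATCH_INTEGRAL_TYPES(t.scalar_type(), "my_kernel", [&] {
+//     my_kernel<scalar_t><<<grid, block, 0, stream>>>(t.data_ptr<scalar_t>());
+//   });
+//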
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)       \
+  AT_DISPATCH_SWITCH(                                       \
+    TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
diff --git a/vllm/model_executor/layers/fused_moe/csrc/moe_align_block_size_kernels.cu b/vllm/model_executor/layers/fused_moe/csrc/moe_align_block_size_kernels.cu
new file mode 100644
index 0000000000000..b83bb6794591d
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/csrc/moe_align_block_size_kernels.cu
@@ -0,0 +1,145 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <ATen/ATen.h>
+#include <THC/THCAtomics.cuh>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
+
+namespace vllm {
+
+namespace {
+__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
+    // don't worry about overflow because num_experts is relatively small
+    return row * total_col + col;
+}
+}
+
+template <typename scalar_t>
+__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
+                                            int32_t *sorted_token_ids,
+                                            int32_t *expert_ids,
+                                            int32_t *total_tokens_post_pad,
+                                            int32_t num_experts,
+                                            int32_t block_size,
+                                            size_t numel,
+                                            int32_t *expert_offset,
+                                            int32_t *expert_length) {
+    const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+    const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+    extern __shared__ int32_t shared_mem[];
+
+    int32_t *tokens_cnts = shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
+    int32_t *cumsum = shared_mem + (num_experts + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+
+    for (int i = 0; i < num_experts; ++i) {
+        tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+    }
+
+    /**
+     * In the first step we compute token_cnts[thread_index + 1][expert_index],
+     * which counts how many tokens in the token shard of thread_index are
+     * assigned to expert expert_index.
+     */
+    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+        ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+    }
+
+    __syncthreads();
+
+    // For each expert we accumulate the token counts from the different threads.
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+        tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+    }
+
+    __syncthreads();
+
+    // We accumulate the token counts of all experts in thread 0.
+    if (threadIdx.x == 0) {
+        cumsum[0] = 0;
+        expert_offset[0] = 0;
+        for (int i = 1; i <= num_experts; ++i) {
+            cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
+            expert_offset[i] = cumsum[i];
+            expert_length[i - 1] = tokens_cnts[index(num_experts, blockDim.x, i - 1)];
+        }
+        *total_tokens_post_pad = cumsum[num_experts];
+    }
+
+    __syncthreads();
+
+    /**
+     * For each expert, each thread processes the tokens of the corresponding
+     * blocks and stores the corresponding expert_id for each block.
+     */
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+        expert_ids[i / block_size] = threadIdx.x;
+    }
+
+    /**
+     * Each thread processes a token shard, calculating the index of each
+     * token after sorting by expert number.
+     * Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and block_size = 4,
+     * the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
+     * where * represents a padding value (preset in Python).
+     */
+    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+        int32_t expert_id = topk_ids[i];
+        /**
+         * The cumsum[expert_id] stores the starting index of the tokens that
+         * the expert with expert_id needs to process, and
+         * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
+         * processed by the expert with expert_id within the current thread's
+         * token shard.
+         */
+        int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
+        sorted_token_ids[rank_post_pad] = i;
+        ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+    }
+}
+}
+
+void moe_align_block_size(
+    torch::Tensor topk_ids,
+    int num_experts,
+    int block_size,
+    torch::Tensor sorted_token_ids,
+    torch::Tensor experts_ids,
+    torch::Tensor num_tokens_post_pad,
+    torch::Tensor expert_offset,
+    torch::Tensor expert_length) {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+            // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
+            const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+
+            // set dynamic shared mem
+            auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
+            AT_CUDA_CHECK(
+                VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
+            kernel<<<1, num_experts, shared_mem, stream>>>(
+                topk_ids.data_ptr<scalar_t>(),
+                sorted_token_ids.data_ptr<int32_t>(),
+                experts_ids.data_ptr<int32_t>(),
+                num_tokens_post_pad.data_ptr<int32_t>(),
+                num_experts,
+                block_size,
+                topk_ids.numel(),
+                expert_offset.data_ptr<int32_t>(),
+                expert_length.data_ptr<int32_t>());
+        });
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
+    ops.def(
+        "moe_align_block_size",
+        &moe_align_block_size,
+        "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
+}
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 78e7e4616d9a6..ae2c7278b5dd9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -454,7 +454,8 @@ def fused_moe(
                             config,
                             compute_type=compute_type,
                             use_fp8=use_fp8)
-
+    return intermediate_cache1
+
     ops.silu_and_mul(intermediate_cache2,
                      intermediate_cache1.view(-1, N))
 
     invoke_fused_moe_kernel(intermediate_cache2,
diff --git a/vllm/model_executor/layers/fused_moe/gather_scatter_kernel.py b/vllm/model_executor/layers/fused_moe/gather_scatter_kernel.py
new file mode 100644
index 0000000000000..3400b3669f83b
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/gather_scatter_kernel.py
@@ -0,0 +1,373 @@
+import functools
+from typing import Tuple
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+
+import vllm
+from vllm import _custom_ops as ops
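+
+
+# timeit_decorator below benchmarks a callable by warming it up on a side
+# stream, capturing a single invocation into a CUDA graph, and timing
+# `times` graph replays with CUDA events; replaying the captured graph
+# avoids the Python-side launch overhead that would otherwise dominate a
+# microbenchmark of short kernels.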
+def timeit_decorator(times=100):
+    def decorator(function_call):
+        @functools.wraps(function_call)
+        def wrapper(*args, **kwargs):
+
+            # cuda graph
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for i in range(3):
+                    function_call(*args, **kwargs)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                function_call(*args, **kwargs)
+
+            all_time = 0.0
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            start.record()
+            for j in range(times):
+                # function_call(*args, **kwargs)
+                g.replay()
+
+            end.record()
+            torch.cuda.synchronize()
+            elapsed_time_ms = start.elapsed_time(end)
+            all_time = elapsed_time_ms
+
+            avg_time = all_time / times
+            print(f"{function_call.__name__} average time: {avg_time} ms")
+            return function_call(*args, **kwargs)
+
+        return wrapper
+    return decorator
+
+
+@triton.jit
+def moe_gather(
+    a_ptr,
+    c_ptr,
+    sorted_token_ids_ptr,
+    num_tokens_post_padded_ptr,
+    M,
+    K,
+    EM,
+    num_valid_tokens,
+    stride_am,
+    stride_ak,
+    stride_cm,
+    stride_ck,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    topk: tl.constexpr,
+    splitk: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    pid_m = pid // splitk
+    pid_n = pid % splitk
+
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // topk * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    c_ptrs = c_ptr + (offs_token_id[:, None] * stride_cm + offs_k[None, :] * stride_ck)
+    w_token_mask = offs_token_id < num_tokens_post_padded
+
+    SPLITED_K = tl.cdiv(K, BLOCK_SIZE_K) // splitk
+
+    a_ptrs = a_ptrs + pid_n * SPLITED_K * BLOCK_SIZE_K * stride_ak
+    c_ptrs = c_ptrs + pid_n * SPLITED_K * BLOCK_SIZE_K * stride_ck
+
+    for k in range(pid_n * SPLITED_K, (pid_n + 1) * SPLITED_K):
+        a_mask = token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
+        c_mask = w_token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
+        a = tl.load(
+            a_ptrs,
+            mask=a_mask,
+            other=0.0,
+        )
+        tl.store(c_ptrs, a, mask=c_mask)
+
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        c_ptrs += BLOCK_SIZE_K * stride_ck
+
+
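+# moe_gather copies each (token, expert) row of the activation into the
+# block-padded, expert-sorted layout produced by moe_align_block_size;
+# moe_scatter below is its inverse, writing rows back to their original
+# token slots and optionally scaling them by the routing weight. Both split
+# the feature dimension across `splitk` programs so that wide rows are
+# copied by several thread blocks in parallel.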
+@triton.jit
+def moe_scatter(
+    a_ptr,
+    c_ptr,
+    sorted_token_ids_ptr,
+    num_tokens_post_padded_ptr,
+    topk_weights_ptr,
+    M,
+    K,
+    EM,
+    num_valid_tokens,
+    stride_am,
+    stride_ak,
+    stride_cm,
+    stride_ck,
+    # Meta-parameters
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    topk: tl.constexpr,
+    splitk: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    pid_m = pid // splitk
+    pid_n = pid % splitk
+
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    a_ptrs = a_ptr + (offs_token_id[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    a_token_mask = offs_token_id < num_tokens_post_padded
+
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    w_token_mask = offs_token < num_valid_tokens
+
+    c_ptrs = c_ptr + (offs_token[:, None] * stride_cm + offs_k[None, :] * stride_ck)
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=w_token_mask, other=0)
+
+    SPLITED_K = tl.cdiv(K, BLOCK_SIZE_K) // splitk
+    a_ptrs = a_ptrs + pid_n * SPLITED_K * BLOCK_SIZE_K * stride_ak
+    c_ptrs = c_ptrs + pid_n * SPLITED_K * BLOCK_SIZE_K * stride_ck
+
+    for k in range(pid_n * SPLITED_K, (pid_n + 1) * SPLITED_K):
+        a_mask = a_token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
+        c_mask = w_token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)
+        a = tl.load(
+            a_ptrs,
+            mask=a_mask,
+            other=0.0,
+        )
+        if MUL_ROUTED_WEIGHT:
+            a = a * moe_weight[:, None]
+        tl.store(c_ptrs, a, mask=c_mask)
+
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        c_ptrs += BLOCK_SIZE_K * stride_ck
+
+
+def sparsemixer(scores, top_k, jitter_eps=0.01):
+    assert top_k == 2
+
+    ################ first expert ################
+
+    with torch.no_grad():
+        # compute mask for sparsity
+        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (
+            2 * jitter_eps
+        )
+
+    # apply mask
+    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
+    selected_experts = max_ind
+
+    # compute scores for gradients
+    masked_gates = torch.softmax(masked_gates, dim=-1)
+    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
+
+    multiplier = multiplier_o
+
+    # masked out first expert
+    masked_scores = torch.scatter(
+        scores,
+        -1,
+        selected_experts,
+        float("-inf"),
+    )
+    with torch.no_grad():
+        # compute mask for sparsity
+        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
+        factor = scores.abs().clamp(min=mask_logits_threshold)
+        mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (
+            2 * jitter_eps
+        )
+
+    # apply mask
+    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
+    selected_experts_top2 = max_ind
+    # compute scores for gradients
+    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
+    multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)
+
+    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
+    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)
+
+    return (
+        multiplier,
+        selected_experts,
+    )
+
+
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    sorted_ids = torch.empty(
+        (topk_ids.numel() + num_experts * (block_size - 1),),
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    expert_ids = torch.empty(
+        (topk_ids.numel() + num_experts,), dtype=torch.int32, device=topk_ids.device
+    )
+    sorted_ids.fill_(topk_ids.numel())
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
+
+
+def invoke_moe_gather(
+    inp,
+    outp,
+    sorted_token_ids,
+    num_tokens_post_padded,
+    topk_ids,
+    block_m,
+    block_k,
+    topk,
+    splitk=1,
+):
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], block_m) * splitk,)
+
+    moe_gather[grid](
+        inp,
+        outp,
+        sorted_token_ids,
+        num_tokens_post_padded,
+        inp.size(0),
+        inp.size(1),
+        sorted_token_ids.size(0),
+        topk_ids.numel(),
+        inp.stride(0),
+        inp.stride(1),
+        outp.stride(0),
+        outp.stride(1),
+        BLOCK_SIZE_M=block_m,
+        BLOCK_SIZE_K=block_k,
+        topk=topk,
+        splitk=splitk,
+    )
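+
+
+# Example (shapes only; the values match test_gather_scatter below): with
+# 4096 tokens, hidden size 4096, topk=2, block_m=128 and splitk=4,
+#
+#     invoke_moe_gather(hidden_states,   # [4096, 4096]
+#                       gathered,        # [sorted_token_ids.size(0), 4096]
+#                       sorted_token_ids, num_tokens_post_padded,
+#                       topk_ids, 128, 128, 2, 4)
+#
+# fills `gathered` so that rows routed to the same expert are contiguous;
+# the round trip through invoke_moe_scatter recovers the original order.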
+
+
+def invoke_moe_scatter(
+    inp,
+    outp,
+    sorted_token_ids,
+    num_tokens_post_padded,
+    topk_ids,
+    block_m,
+    block_k,
+    topk,
+    splitk=1,
+    topk_weights=None,
+):
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], block_m) * splitk,)
+
+    moe_scatter[grid](
+        inp,
+        outp,
+        sorted_token_ids,
+        num_tokens_post_padded,
+        topk_weights,
+        inp.size(0),
+        inp.size(1),
+        sorted_token_ids.size(0),
+        topk_ids.numel(),
+        inp.stride(0),
+        inp.stride(1),
+        outp.stride(0),
+        outp.stride(1),
+        topk_weights is not None,
+        BLOCK_SIZE_M=block_m,
+        BLOCK_SIZE_K=block_k,
+        topk=topk,
+        splitk=splitk,
+    )
+
+
+def test_gather_scatter(tokens=4096, hidden_size=4096, experts=16, block_m=128, block_k=128, topk=2, splitk=4):
+    hidden_states = torch.randn(tokens, hidden_size).cuda().bfloat16()
+    gatew = torch.randn(hidden_size, experts).cuda().half()
+    gating_output = torch.matmul(hidden_states.half(), gatew).float()
+    topk_weights, topk_ids = sparsemixer(gating_output, topk)
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, block_m, experts
+    )
+
+    intermediate_cache1 = torch.zeros(
+        (sorted_token_ids.size(0), hidden_size),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    invoke_moe_gather(
+        hidden_states,
+        intermediate_cache1,
+        sorted_token_ids,
+        num_tokens_post_padded,
+        topk_ids,
+        block_m,
+        block_k,
+        topk,
+        splitk,
+    )
+
+    print("hidden_states")
+    print(hidden_states)
+    print("intermediate_cache1")
+    print(intermediate_cache1)
+
+    intermediate_cache2 = torch.zeros(
+        (tokens * topk, hidden_size),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    invoke_moe_scatter(
+        intermediate_cache1,
+        intermediate_cache2,
+        sorted_token_ids,
+        num_tokens_post_padded,
+        topk_ids,
+        block_m,
+        block_k,
+        topk,
+        splitk,
+    )
+
+    print("intermediate_cache2")
+    print(intermediate_cache2)
+    new_ic_2 = intermediate_cache2.reshape(tokens, topk, hidden_size)[:, 0, :]
+
+    torch.testing.assert_close(hidden_states, new_ic_2)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/pycublas.zip b/vllm/model_executor/layers/fused_moe/pycublas.zip
new file mode 100644
index 0000000000000..8eaad77aa9f7b
Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/pycublas.zip differ
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 5f72606285094..5483b0d950930 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -61,6 +61,7 @@ def is_sm80(device_id=0):
     return (device_properties.major == 8 and device_properties.minor == 0)
 
 if is_sm80():
+    import pycublas.trtllm_moe_grouped_gemm as moe_kernel
     from vllm.model_executor.layers.fused_moe import ampere_fp8_fused_moe
 
     fused_moe_a100 = ampere_fp8_fused_moe.fused_moe
@@ -259,6 +260,7 @@ def __init__(
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
         self.apply_a100_fp8 = is_sm80() and self.use_fp8
+        self.remove_subnormal_fp8 = False
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -347,20 +349,29 @@ def process_weights_after_loading(self):
                 w2s[expert, :, :], self.w2s_scale[
                     expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
 
-        def remove_subnormal_fp8(tensor):
-            assert tensor.dtype == torch.uint8, "Tensor must be a byte tensor representing fp8 values"
-            exponent_mask = 0b11111000
-            mantissa_mask = 0b00000111
-            exponents = (tensor & exponent_mask) >> 3
-            mantissas = tensor & mantissa_mask
-            subnormal_mask = (exponents == 0) & (mantissas != 0)
-            if subnormal_mask.any():
-                print(subnormal_mask.sum().item() / subnormal_mask.numel() * 100, "% of values are subnormal")
-                tensor[subnormal_mask] = 0
-            return subnormal_mask.any()
-
-        #remove_subnormal_fp8(ws.view(torch.uint8))
-        #remove_subnormal_fp8(w2s.view(torch.uint8))
+        if self.apply_a100_fp8 and self.remove_subnormal_fp8:
+            print_warning_once("Removing FP8 subnormal values from weights")
+
+            def remove_subnormal_fp8(tensor):
+                assert tensor.dtype == torch.uint8, "Tensor must be a byte tensor representing fp8 values"
+                # fp8 e4m3: 1 sign bit, 4 exponent bits, 3 mantissa bits;
+                # mask only the exponent so negative subnormals are caught too
+                exponent_mask = 0b01111000
+                mantissa_mask = 0b00000111
+                exponents = (tensor & exponent_mask) >> 3
+                mantissas = tensor & mantissa_mask
+                subnormal_mask = (exponents == 0) & (mantissas != 0)
+                if subnormal_mask.any():
+                    print(subnormal_mask.sum().item() / subnormal_mask.numel() * 100, "% of values are subnormal")
+                    tensor[subnormal_mask] = 0
+                return subnormal_mask.any()
+
+            remove_subnormal_fp8(ws.view(torch.uint8))
+            remove_subnormal_fp8(w2s.view(torch.uint8))
+
+        if self.apply_a100_fp8:
+            print_warning_once("Preprocessing weights for A100 FP8 fused MoE")
+            ws = moe_kernel.preprocess_weights_for_mixed_gemm(ws.view(torch.int8).transpose(1, 2).contiguous().cpu()).to(w2s.device)
+            w2s = moe_kernel.preprocess_weights_for_mixed_gemm(w2s.view(torch.int8).transpose(1, 2).contiguous().cpu()).to(ws.device)
+            self.ws_scale = nn.Parameter(self.ws_scale.to(dtype=torch.float16).unsqueeze(1).expand(-1, ws.size(-1)).contiguous(), requires_grad=False)
+            self.w2s_scale = nn.Parameter(self.w2s_scale.to(dtype=torch.float16).unsqueeze(1).expand(-1, w2s.size(-1)).contiguous(), requires_grad=False)
 
         self.ws = nn.Parameter(ws.to("cuda"), requires_grad=False)
         self.w2s = nn.Parameter(w2s.to("cuda"), requires_grad=False)