# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from typing import Optional

import torch

from vllm import _custom_ops as ops


# TODO: make the grouped gemm kernel consistent with the scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ab_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
    """
    This function computes an a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
        Shape: [M, topk]
    - topk_ids (torch.Tensor): The expert indices of each token->expert
        mapping. Shape: [M, topk]
    - ab_strides1 (torch.Tensor): The input and weights strides of the first
        grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - ab_strides2 (torch.Tensor): The input and weights strides of the second
        grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]
    - out_dtype (torch.dtype): The output tensor type.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer,
        with dtype out_dtype (fp16 or bf16).
    """

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.float8_e4m3fn
    assert w2_q.dtype == torch.float8_e4m3fn
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
        0], "Input scale shape mismatch"
    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
        1] == w1_q.shape[2], "W1 scale shape mismatch"
    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
        1] == w2_q.shape[2], "W2 scale shape mismatch"
    assert w1_q.shape[0] == w1_scale.shape[
        0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[
        0], "w2 scales expert number mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
    assert ab_strides1.shape[0] == w1_q.shape[
        0], "AB Strides 1 expert number mismatch"
    assert c_strides1.shape[0] == w1_q.shape[
        0], "C Strides 1 expert number mismatch"
    assert ab_strides2.shape[0] == w2_q.shape[
        0], "AB Strides 2 expert number mismatch"
    assert c_strides2.shape[0] == w2_q.shape[
        0], "C Strides 2 expert number mismatch"
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(1)
    n = w2_q.size(1)

    topk = topk_ids.size(1)
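    # Per-token activation quantization is used when either provided
    # activation scale has more than one element; otherwise the scales are
    # treated as per-tensor.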
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)

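    # Quantize the input activations to fp8 (the scale is computed
    # dynamically when a1_scale is not provided).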
    a_q, a1_scale = ops.scaled_fp8_quant(
        a, a1_scale, use_per_token_if_dynamic=per_act_token)
    device = a_q.device

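    # Per-expert bookkeeping for the grouped gemms: the token offsets of each
    # expert's sub-matrix and the per-expert gemm problem shapes.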
    expert_offsets = torch.empty((num_experts + 1),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)

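    # Permutation maps: a_map gathers tokens into expert-contiguous order for
    # the gemms, c_map maps the results back to token order.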
    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

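    # Fill the offsets, problem sizes and permutation maps from the top-k
    # expert assignment.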
    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, num_experts, n,
                                k)

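    # Replicate each token's quantized activations (and per-token scales, if
    # any) once per selected expert; the uint8 round-trip lets the gather
    # operate on a dtype that supports advanced indexing.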
    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale

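    # Outputs of the two grouped gemms, one row per (token, expert) pair.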
    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

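    # First grouped gemm: per expert, [tokens, K] x [K, 2N] -> [tokens, 2N]
    # (the fused gate/up projection).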
    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
                       expert_offsets[:-1], problem_sizes1, ab_strides1,
                       ab_strides1, c_strides1)

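    # SwiGLU activation: SiLU applied to the first half of c1, multiplied by
    # the second half, reducing the width from 2N to N.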
    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    torch.ops._C.silu_and_mul(intermediate, c1)

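    # Quantize the intermediate activations to fp8 for the second grouped
    # gemm.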
    intermediate_q, a2_scale = ops.scaled_fp8_quant(
        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)

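    # Second grouped gemm: per expert, [tokens, N] x [N, K] -> [tokens, K]
    # (the down projection).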
    ops.cutlass_moe_mm(c2, intermediate_q, w2_q, a2_scale, w2_scale,
                       expert_offsets[:-1], problem_sizes2, ab_strides2,
                       ab_strides2, c_strides2)

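    # Map the rows back to the original token order, weight each expert's
    # output by its routing weight, and sum over the top-k experts.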
    return (c2[c_map].view(m, topk, k) *
            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
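

# Illustrative call sketch (assumptions for a hypothetical setup, not part of
# this kernel): with E experts, hidden size K and intermediate size N, and
# weights/scales already quantized into the layouts described in the
# docstring, the per-expert stride tensors for contiguous row-major tensors
# are simply the row lengths:
#
#   ab_strides1 = torch.full((E,), K, dtype=torch.int64, device="cuda")
#   c_strides1 = torch.full((E,), 2 * N, dtype=torch.int64, device="cuda")
#   ab_strides2 = torch.full((E,), N, dtype=torch.int64, device="cuda")
#   c_strides2 = torch.full((E,), K, dtype=torch.int64, device="cuda")
#   out = cutlass_moe_fp8(a, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
#                         topk_ids, ab_strides1, c_strides1, ab_strides2,
#                         c_strides2, out_dtype=torch.half)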