|
| 1 | +"""Fused MoE kernel with FP8 weight using Ampere.""" |
| 2 | + |
| 3 | +import vllm |
| 4 | +import torch |
| 5 | +from vllm import _custom_ops as ops |
| 6 | +import cupy |
| 7 | +from typing import Dict, Any, Optional, Callable |
| 8 | + |
| 9 | +import torch |
| 10 | +import triton |
| 11 | +import triton.language as tl |
| 12 | + |
| 13 | +from vllm import _custom_ops as ops |
| 14 | +from vllm.logger import init_logger |
| 15 | +from vllm.utils import is_hip |
| 16 | + |
| 17 | +logger = init_logger(__name__) |
| 18 | + |
| 19 | +from vllm.model_executor.layers.fused_moe.fused_moe import ( |
| 20 | + get_moe_configs, |
| 21 | + moe_align_block_size, |
| 22 | + invoke_fused_moe_kernel, |
| 23 | +) |
| 24 | + |
| 25 | +# <todo:wenxh> Kernels performance needs to be optimized |
| 26 | +# such as one thread deals with multiple elements to reduce memory transaction. |
| 27 | + |
| 28 | +convert_fp8e4m3_to_half = cupy.RawKernel( |
| 29 | + r""" |
| 30 | +#include "cuda_fp8.h" |
| 31 | +#include "cuda_fp16.h" |
| 32 | +extern "C" __global__ |
| 33 | +void convert_fp8e4m3_to_half(const __nv_fp8_storage_t* x, float *scale_p, half* y, int size) { |
| 34 | + int tid = blockDim.x * blockIdx.x + threadIdx.x; |
| 35 | + float scale = *scale_p; |
| 36 | + if (tid < size) |
| 37 | + y[tid] = __nv_cvt_fp8_to_halfraw(x[tid], __NV_E4M3) * scale; |
| 38 | +} |
| 39 | +""", |
| 40 | + "convert_fp8e4m3_to_half", |
| 41 | +) |
| 42 | + |
| 43 | +convert_fp8e4m3_to_bfloat16 = cupy.RawKernel( |
| 44 | + r""" |
| 45 | +#include "cuda_fp8.h" |
| 46 | +#include "cuda_fp16.h" |
| 47 | +#include "cuda_bf16.h" |
| 48 | +extern "C" __global__ |
| 49 | +void convert_fp8e4m3_to_bfloat16(const __nv_fp8_storage_t* x, float* scale_p, __nv_bfloat16* y, int size) { |
| 50 | + int tid = blockDim.x * blockIdx.x + threadIdx.x; |
| 51 | + float scale = *scale_p; |
| 52 | + if (tid < size) |
| 53 | + y[tid] = __float2bfloat16(__nv_cvt_fp8_to_halfraw(x[tid], __NV_E4M3) * scale); |
| 54 | +} |
| 55 | +""", |
| 56 | + "convert_fp8e4m3_to_bfloat16", |
| 57 | +) |
| 58 | + |
| 59 | + |
| 60 | +def dequantize_fp8(t_fp8, scales, dtype=torch.float16): |
| 61 | + s = torch.empty_like(t_fp8, dtype=dtype) |
| 62 | + convert = ( |
| 63 | + convert_fp8e4m3_to_half |
| 64 | + if dtype == torch.float16 |
| 65 | + else convert_fp8e4m3_to_bfloat16 |
| 66 | + ) |
| 67 | + |
| 68 | + expert_num = t_fp8.shape[0] |
| 69 | + |
| 70 | + expert_in = torch.chunk(t_fp8, expert_num, dim=0) |
| 71 | + expert_out = torch.chunk(s, expert_num, dim=0) |
| 72 | + |
| 73 | + for i in range(expert_num): |
| 74 | + scale = scales[i] |
| 75 | + convert( |
| 76 | + ((expert_in[i].numel() + 1024 - 1) // 1024,), |
| 77 | + (1024,), |
| 78 | + (expert_in[i].data_ptr(), scale.data_ptr(), expert_out[i].data_ptr(), t_fp8.numel()), |
| 79 | + ) |
| 80 | + return s |
| 81 | + |
| 82 | + |
| 83 | +def fused_moe( |
| 84 | + hidden_states: torch.Tensor, |
| 85 | + w1: torch.Tensor, |
| 86 | + w2: torch.Tensor, |
| 87 | + gating_output: torch.Tensor, |
| 88 | + topk: int, |
| 89 | + renormalize: bool, |
| 90 | + training: bool = False, |
| 91 | + sparse_mixer: bool = False, |
| 92 | + inplace: bool = False, |
| 93 | + override_config: Optional[Dict[str, Any]] = None, |
| 94 | + use_fp8: bool = False, |
| 95 | + w1_scale: Optional[torch.Tensor] = None, |
| 96 | + w2_scale: Optional[torch.Tensor] = None, |
| 97 | + a1_scale: Optional[torch.Tensor] = None, |
| 98 | + a2_scale: Optional[torch.Tensor] = None, |
| 99 | + routing_func: Callable = torch.topk, |
| 100 | +) -> torch.Tensor: |
| 101 | + """ |
| 102 | + This function computes a Mixture of Experts (MoE) layer using two sets of |
| 103 | + weights, w1 and w2, and top-k gating mechanism. |
| 104 | +
|
| 105 | + This layer works the same as fused_moe, but it is used for the Ampere arch, which does not support fp8. |
| 106 | + By default, to be more comparable to Hopper, we reuse E4M3 configuration. |
| 107 | + <todo:wenxh> Use FP8E4b16 to reduce overhead: |
| 108 | + https://github.com/triton-lang/triton/blob/d7c8b3d7890125f5fc1b9f046e3189baa2665be4/python/triton/language/extra/cuda/utils.py#L34 |
| 109 | +
|
| 110 | + Parameters: |
| 111 | + - hidden_states (torch.Tensor): The input tensor to the MoE layer. |
| 112 | + - w1 (torch.Tensor): The first set of expert weights. |
| 113 | + - w2 (torch.Tensor): The second set of expert weights. |
| 114 | + - gating_output (torch.Tensor): The output of the gating operation |
| 115 | + (before softmax). |
| 116 | + - topk (int): The number of top-k experts to select. |
| 117 | + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. |
| 118 | + - inplace (bool): If True, perform the operation in-place. |
| 119 | + Defaults to False. |
| 120 | + - override_config (Optional[Dict[str, Any]]): Optional override |
| 121 | + for the kernel configuration. |
| 122 | + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner |
| 123 | + products for w1 and w2. Defaults to False. |
| 124 | + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for |
| 125 | + w1. |
| 126 | + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for |
| 127 | + w2. |
| 128 | +
|
| 129 | + Returns: |
| 130 | + - torch.Tensor: The output tensor after applying the MoE layer. |
| 131 | + """ |
| 132 | + # Check constraints. |
| 133 | + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" |
| 134 | + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" |
| 135 | + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" |
| 136 | + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" |
| 137 | + assert w1.is_contiguous(), "Expert weights1 must be contiguous" |
| 138 | + assert w2.is_contiguous(), "Expert weights2 must be contiguous" |
| 139 | + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] |
| 140 | + M, _ = hidden_states.shape |
| 141 | + E, N, _ = w1.shape |
| 142 | + |
| 143 | + if routing_func != torch.topk: |
| 144 | + topk_weights, topk_ids = routing_func(gating_output, topk) |
| 145 | + elif is_hip(): |
| 146 | + # The MoE kernels are not yet supported on ROCm. |
| 147 | + routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32) |
| 148 | + topk_weights, topk_ids = routing_func(routing_weights, topk) |
| 149 | + else: |
| 150 | + import vllm._moe_C as moe_kernels |
| 151 | + |
| 152 | + topk_weights = torch.empty( |
| 153 | + M, topk, dtype=torch.float32, device=hidden_states.device |
| 154 | + ) |
| 155 | + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) |
| 156 | + token_expert_indicies = torch.empty( |
| 157 | + M, topk, dtype=torch.int32, device=hidden_states.device |
| 158 | + ) |
| 159 | + moe_kernels.topk_softmax( |
| 160 | + topk_weights, |
| 161 | + topk_ids, |
| 162 | + token_expert_indicies, |
| 163 | + gating_output.float(), # TODO(woosuk): Optimize this. |
| 164 | + ) |
| 165 | + del token_expert_indicies # Not used. Will be used in the future. |
| 166 | + if renormalize: |
| 167 | + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) |
| 168 | + |
| 169 | + if override_config: |
| 170 | + config = override_config |
| 171 | + else: |
| 172 | + # First try to load optimal config from the file |
| 173 | + configs = get_moe_configs(E, w2.shape[2], None) |
| 174 | + |
| 175 | + if configs: |
| 176 | + # If an optimal configuration map has been found, look up the |
| 177 | + # optimal config |
| 178 | + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] |
| 179 | + else: |
| 180 | + # Else use the default config |
| 181 | + config = { |
| 182 | + "BLOCK_SIZE_M": 64, |
| 183 | + "BLOCK_SIZE_N": 64, |
| 184 | + "BLOCK_SIZE_K": 32, |
| 185 | + "GROUP_SIZE_M": 8, |
| 186 | + } |
| 187 | + |
| 188 | + if M <= E: |
| 189 | + config = { |
| 190 | + "BLOCK_SIZE_M": 16, |
| 191 | + "BLOCK_SIZE_N": 32, |
| 192 | + "BLOCK_SIZE_K": 64, |
| 193 | + "GROUP_SIZE_M": 1, |
| 194 | + } |
| 195 | + |
| 196 | + if M == 1: |
| 197 | + # expert, hs1, hs2 |
| 198 | + topk_w1 = w1.view(torch.uint8)[topk_ids.flatten()] |
| 199 | + topk_w2 = w2.view(torch.uint8)[topk_ids.flatten()] |
| 200 | + topk_ids = torch.arange( |
| 201 | + topk, device=topk_ids.device, dtype=topk_ids.dtype |
| 202 | + ).unsqueeze(0) |
| 203 | + |
| 204 | + E = topk |
| 205 | + |
| 206 | + w1_scale = w1_scale[topk_ids.flatten()] |
| 207 | + w1 = dequantize_fp8(topk_w1, w1_scale, dtype=hidden_states.dtype) |
| 208 | + |
| 209 | + else: |
| 210 | + w1 = dequantize_fp8(w1, w1_scale, dtype=hidden_states.dtype) |
| 211 | + |
| 212 | + use_fp8 = False |
| 213 | + w1_scale = None |
| 214 | + a1_scale = None |
| 215 | + a2_scale = None |
| 216 | + |
| 217 | + intermediate_cache1 = torch.empty( |
| 218 | + (M, topk_ids.shape[1], N), |
| 219 | + device=hidden_states.device, |
| 220 | + dtype=hidden_states.dtype, |
| 221 | + ) |
| 222 | + intermediate_cache2 = torch.empty( |
| 223 | + (M * topk_ids.shape[1], N // 2), |
| 224 | + device=hidden_states.device, |
| 225 | + dtype=hidden_states.dtype, |
| 226 | + ) |
| 227 | + intermediate_cache3 = torch.empty( |
| 228 | + (M, topk_ids.shape[1], w2.shape[1]), |
| 229 | + device=hidden_states.device, |
| 230 | + dtype=hidden_states.dtype, |
| 231 | + ) |
| 232 | + |
| 233 | + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( |
| 234 | + topk_ids, config["BLOCK_SIZE_M"], E |
| 235 | + ) |
| 236 | + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 |
| 237 | + |
| 238 | + invoke_fused_moe_kernel( |
| 239 | + hidden_states, |
| 240 | + w1, |
| 241 | + intermediate_cache1, |
| 242 | + a1_scale, |
| 243 | + w1_scale, |
| 244 | + topk_weights, |
| 245 | + topk_ids, |
| 246 | + sorted_token_ids, |
| 247 | + expert_ids, |
| 248 | + num_tokens_post_padded, |
| 249 | + False, |
| 250 | + topk_ids.shape[1], |
| 251 | + config, |
| 252 | + compute_type=compute_type, |
| 253 | + use_fp8=use_fp8, |
| 254 | + ) |
| 255 | + |
| 256 | + del w1 |
| 257 | + |
| 258 | + if M == 1: |
| 259 | + w2_scale = w2_scale[topk_ids.flatten()] |
| 260 | + w2 = dequantize_fp8(topk_w2, w2_scale, dtype=hidden_states.dtype) |
| 261 | + else: |
| 262 | + w2 = dequantize_fp8(w2, w2_scale, dtype=hidden_states.dtype) |
| 263 | + |
| 264 | + w2_scale = None |
| 265 | + |
| 266 | + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) |
| 267 | + |
| 268 | + invoke_fused_moe_kernel( |
| 269 | + intermediate_cache2, |
| 270 | + w2, |
| 271 | + intermediate_cache3, |
| 272 | + a2_scale, |
| 273 | + w2_scale, |
| 274 | + topk_weights, |
| 275 | + topk_ids, |
| 276 | + sorted_token_ids, |
| 277 | + expert_ids, |
| 278 | + num_tokens_post_padded, |
| 279 | + True, |
| 280 | + 1, |
| 281 | + config, |
| 282 | + compute_type=compute_type, |
| 283 | + use_fp8=use_fp8, |
| 284 | + ) |
| 285 | + |
| 286 | + del w2 |
| 287 | + |
| 288 | + if inplace: |
| 289 | + return torch.sum( |
| 290 | + intermediate_cache3.view(*intermediate_cache3.shape), |
| 291 | + dim=1, |
| 292 | + out=hidden_states, |
| 293 | + ) |
| 294 | + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) |
0 commit comments