#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "moe_wna16_utils.h"

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

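// Grouped GEMM kernel for MoE layers with 4-bit / 8-bit quantized weights and
// fp16/bf16 activations (WNA16). Each thread block processes one block of
// BLOCK_SIZE_M routed tokens (blockIdx.x), a BLOCK_SIZE_N slice of the output
// features (blockIdx.y) and a BLOCK_SIZE_K slice of the reduction dimension
// (blockIdx.z); partial sums from different K slices are combined with
// atomicAdd. GROUPS = BLOCK_SIZE_K / group_size is the number of quantization
// groups covered by one K slice.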
template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
    const scalar_t* __restrict__ input, scalar_t* __restrict__ output,

    const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
    const uint32_t* __restrict__ qzeros,

    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_token_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ num_tokens_post_pad,

    uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
    uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
    uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
    bool mul_topk_weight) {
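  // the bf16 math intrinsics used below require compute capability 8.0+, so
  // the bf16 instantiation becomes a no-op on older architectures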
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    return;
  } else {
#endif

  using Dtype = ScalarType<scalar_t>;
  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;

  if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;

  const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
  const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;

  const int32_t expert_id = expert_ids[blockIdx.x];

  int32_t num_valid_tokens = 0;
  extern __shared__ uint16_t block_input_tmp[];
  scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
  scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);

  // load a BLOCK_SIZE_M x BLOCK_SIZE_K input tile into shared memory
  for (int m = 0; m < BLOCK_SIZE_M; m++) {
    const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
    const int32_t token_index = sorted_token_ids[offset_m];
    if (token_index / top_k >= size_m) break;

    num_valid_tokens = m + 1;
    if (blockIdx.z == 0 && offset_n < size_n)
      output[token_index * size_n + offset_n] = Dtype::int2num(0);

    if (expert_id != -1) {
      int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
      for (int i = 0; i < k_per_thread; i++) {
        int k = BLOCK_SIZE_N * i + threadIdx.x;
        if (k >= BLOCK_SIZE_K) break;
        if (offset_k + k >= size_k) break;

        // load the input into shared memory, permuted so that it matches
        // the element order produced by dequantizing the packed weights
        int origin_k;
        if constexpr (bit == 4) {
          // [0, 4, 1, 5, 2, 6, 3, 7]
          int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
          origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
        } else {
          // [0, 2, 1, 3]
          int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
          origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
        }

        origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
        block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
      }
    }
  }

  if (expert_id == -1) return;
  __syncthreads();
  if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;

  float res[64];  // assume BLOCK_SIZE_M <= 64
  scalar_t2 res2;
  scalar_t2 scale_f2;
  scalar_t2 qzero_f2;

  // note that (size_n * size_k * expert_id) may be greater than 2**31
  constexpr int8_t pack_factor = 32 / bit;
  const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
  const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
  const scalar_t* expert_scales = scales + expert_offset / group_size;
  const uint32_t* expert_qzeros =
      qzeros + expert_offset / group_size / pack_factor;

  // load 4 x int32 at a time: 4 x int32 = 128 bits = 1 float4;
  // the packed weights themselves are loaded inside the k loop below
  uint32_t expert_qweight_tmp[4];
  float4* expert_qweight_tmp_float4 =
      reinterpret_cast<float4*>(expert_qweight_tmp);

  // load all scales required by this K tile with a single vectorized load
  scalar_t expert_scales_groups[GROUPS];
  int scales_offset_tmp =
      (offset_n * size_k + offset_k) / group_size / GROUPS;
  if constexpr (GROUPS == 1) {
    *expert_scales_groups = expert_scales[scales_offset_tmp];
  } else if constexpr (GROUPS == 2) {
    float* expert_scales_groups_tmp =
        reinterpret_cast<float*>(expert_scales_groups);
    *expert_scales_groups_tmp =
        reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
  } else if constexpr (GROUPS == 4) {
    float2* expert_scales_groups_tmp =
        reinterpret_cast<float2*>(expert_scales_groups);
    *expert_scales_groups_tmp =
        reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
  } else if constexpr (GROUPS == 8) {
    float4* expert_scales_groups_tmp =
        reinterpret_cast<float4*>(expert_scales_groups);
    *expert_scales_groups_tmp =
        reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
  }

  // load all zero points required by this K tile with a single vectorized
  // load; without an explicit zero-point tensor, use the symmetric default
  // (8 for 4-bit, 128 for 8-bit)
  uint8_t expert_qzeros_groups[GROUPS];
  if (!has_zp) {
    if constexpr (bit == 4) {
      qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
    } else {
      qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
    }
  } else {
    int qzeros_offset_tmp =
        (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
        offset_k / group_size / GROUPS;
    if constexpr (GROUPS == 1) {
      uint8_t* expert_qzeros_groups_tmp =
          reinterpret_cast<uint8_t*>(expert_qzeros_groups);
      *expert_qzeros_groups_tmp =
          reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
    } else if constexpr (GROUPS == 2) {
      uint16_t* expert_qzeros_groups_tmp =
          reinterpret_cast<uint16_t*>(expert_qzeros_groups);
      *expert_qzeros_groups_tmp =
          reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
    } else if constexpr (GROUPS == 4) {
      uint32_t* expert_qzeros_groups_tmp =
          reinterpret_cast<uint32_t*>(expert_qzeros_groups);
      *expert_qzeros_groups_tmp =
          reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
    } else if constexpr (GROUPS == 8) {
      uint64_t* expert_qzeros_groups_tmp =
          reinterpret_cast<uint64_t*>(expert_qzeros_groups);
      *expert_qzeros_groups_tmp =
          reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
    }
  }

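  // main loop over the packed 32-bit weight words of this K tile: each word
  // holds pack_factor quantized values; dequantize it and accumulate the
  // partial dot product for every valid token of this block in registers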
  for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
    int k = offset_k + tmp_k * pack_factor;
    if (k >= size_k) break;
    const int32_t weight_offset = offset_n * size_k + k;

    if (tmp_k % 4 == 0) {
      *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
          expert_qweight)[weight_offset / pack_factor / 4];
    }

    if (tmp_k % (group_size / pack_factor) == 0) {
      scalar_t scale_f =
          expert_scales_groups[tmp_k / (group_size / pack_factor)];
      scale_f2 = Dtype::num2num2(scale_f);

      if (has_zp) {
        uint8_t qzero =
            expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
        if constexpr (bit == 4) {
          qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
        }
        qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
      }
    }

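    // unpack one 32-bit word into pack_factor values, already converted to
    // fp16/bf16 and stored as (16 / bit) half2/bfloat162 pairs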
    scalar_t2 weight_half2[16 / bit];
    dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);

    for (int m = 0; m < num_valid_tokens; m++) {
      res2 = {};

#pragma unroll
      for (int i = 0; i < 16 / bit; i++) {
        int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
        res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
                       block_input_half2[offset_input], res2);
      }

      if (tmp_k == 0) {
        res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
      } else {
        res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
      }
    }
  }

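  // write back: each block holds only the partial sum over its K slice, so
  // results are accumulated into the output with atomicAdd, optionally
  // scaled by the token's routing weight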
  for (int m = 0; m < num_valid_tokens; ++m) {
    const int32_t token_index =
        sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
    if (mul_topk_weight) {
      res[m] *= topk_weights[token_index];
    }
    atomicAdd(&output[token_index * size_n + offset_n],
              Dtype::float2num(res[m]));
  }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  }
#endif
}

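// Host-side launcher: picks the kernel instantiation from the bit width and
// GROUPS = BLOCK_SIZE_K / group_size, launches one thread per output feature
// of a BLOCK_SIZE_N tile, and reserves BLOCK_SIZE_M * BLOCK_SIZE_K * 2 bytes
// of dynamic shared memory for the fp16/bf16 input tile.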
template <typename scalar_t>
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
                        const uint32_t* b_qweight, const scalar_t* b_scales,
                        const uint32_t* b_qzeros, const float* topk_weights,
                        const int32_t* sorted_token_ids,
                        const int32_t* expert_ids,
                        const int32_t* num_tokens_post_pad, int num_experts,
                        int group_size, int num_token_blocks, int top_k,
                        int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
                        int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
                        bool has_zp, bool mul_topk_weight) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_SIZE_N;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = num_token_blocks;
  gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
  gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);

  auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
  if (bit == 4) {
    if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
    }
  } else {
    if (BLOCK_SIZE_K / group_size == 1) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
    } else if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
    }
  }

  const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
      input, output, b_qweight, b_scales, b_qzeros, topk_weights,
      sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
      group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
      BLOCK_SIZE_K, has_zp, mul_topk_weight);
}

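// PyTorch entry point. Shapes are inferred from the indexing below and are
// not all asserted here, so treat them as assumptions: input is
// [size_m, size_k] fp16/bf16, b_qweight is packed uint8 viewed as uint32 with
// num_experts x size_n rows and size_k / (32 / bit) packed values per row,
// b_scales is [num_experts, size_n, size_k / group_size], and the optional
// b_qzeros packs one zero point per quantization group. sorted_token_ids,
// expert_ids and num_tokens_post_pad come from the MoE token-alignment
// preprocessing step that sorts token ids by expert and pads each expert's
// token count to a multiple of BLOCK_SIZE_M.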
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto options =
      torch::TensorOptions().dtype(input.dtype()).device(input.device());

  const int num_experts = b_qweight.size(0);
  const int size_m = input.size(0);
  const int size_n = b_qweight.size(1);
  const int size_k = input.size(1);
  const int group_size = size_k / b_scales.size(2);

  int64_t EM = sorted_token_ids.size(0);
  if (size_m <= BLOCK_SIZE_M) {
    EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
  }
  const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;

  const uint32_t* b_qzeros_ptr = nullptr;
  if (b_qzeros.has_value())
    b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
  const float* topk_weights_ptr = nullptr;
  if (topk_weights.has_value())
    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();

  int groups_per_block_row = BLOCK_SIZE_K / group_size;
  TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
  TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
              "size_k must be divisible by BLOCK_SIZE_K");
  TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
              "BLOCK_SIZE_K must be divisible by group_size");
  TORCH_CHECK(BLOCK_SIZE_M <= 64,
              "BLOCK_SIZE_M must be less than or equal to 64");
  TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
                  groups_per_block_row == 4 || groups_per_block_row == 8,
              "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");

| 319 | + |
| 320 | + if (input.scalar_type() == at::ScalarType::Half) { |
| 321 | + run_moe_wna16_gemm<half>( |
| 322 | + (const half*)input.data_ptr<at::Half>(), |
| 323 | + (half*)output.data_ptr<at::Half>(), |
| 324 | + (const uint32_t*)b_qweight.data_ptr<uint8_t>(), |
| 325 | + (const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr, |
| 326 | + topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(), |
| 327 | + expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), |
| 328 | + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, |
| 329 | + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, |
| 330 | + b_qzeros.has_value(), topk_weights.has_value()); |
| 331 | + } else if (input.scalar_type() == at::ScalarType::BFloat16) { |
| 332 | + run_moe_wna16_gemm<nv_bfloat16>( |
| 333 | + (const nv_bfloat16*)input.data_ptr<at::BFloat16>(), |
| 334 | + (nv_bfloat16*)output.data_ptr<at::BFloat16>(), |
| 335 | + (const uint32_t*)b_qweight.data_ptr<uint8_t>(), |
| 336 | + (const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr, |
| 337 | + topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(), |
| 338 | + expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), |
| 339 | + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, |
| 340 | + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, |
| 341 | + b_qzeros.has_value(), topk_weights.has_value()); |
| 342 | + } else { |
| 343 | + TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16"); |
| 344 | + } |
| 345 | + return output; |
| 346 | +} |
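
// Illustrative configuration (not prescribed by this file): BLOCK_SIZE_M=16,
// BLOCK_SIZE_N=64, BLOCK_SIZE_K=256 uses 16 * 256 * 2 B = 8 KB of dynamic
// shared memory per block; with bit=4 and group_size=128, GROUPS = 256 / 128
// = 2, so moe_wna16_gemm_kernel<scalar_t, 4, 2> is selected.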