From 4007cd8ab9465de19f308b6afbdbc6b8bdb75a6c Mon Sep 17 00:00:00 2001 From: wyj371990 Date: Thu, 23 Jan 2025 16:15:59 +0800 Subject: [PATCH 1/3] [Misc][Kernel]: Add GPTQAllSpark Quantization Signed-off-by: wyj371990 --- CMakeLists.txt | 16 + benchmarks/kernels/benchmark_marlin.py | 49 +- .../gptq_allspark/allspark_qgemm_w8a16.cu | 1021 +++++++++++++++++ .../gptq_allspark/allspark_repack.cu | 163 +++ .../gptq_allspark/allspark_utils.cuh | 408 +++++++ csrc/torch_bindings.cpp | 20 + tests/kernels/test_allspark_gemm.py | 103 ++ vllm/_custom_ops.py | 79 ++ vllm/model_executor/layers/linear.py | 9 +- .../schemes/compressed_tensors_wNa16.py | 4 +- .../layers/quantization/gptq_marlin.py | 4 +- .../kernels/mixed_precision/MPLinearKernel.py | 4 +- .../kernels/mixed_precision/__init__.py | 3 + .../kernels/mixed_precision/allspark.py | 125 ++ .../quantization/utils/allspark_utils.py | 51 + .../layers/vocab_parallel_embedding.py | 3 +- 16 files changed, 2053 insertions(+), 9 deletions(-) mode change 100755 => 100644 CMakeLists.txt create mode 100644 csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu create mode 100644 csrc/quantization/gptq_allspark/allspark_repack.cu create mode 100644 csrc/quantization/gptq_allspark/allspark_utils.cuh create mode 100644 tests/kernels/test_allspark_gemm.py create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py create mode 100644 vllm/model_executor/layers/quantization/utils/allspark_utils.py diff --git a/CMakeLists.txt b/CMakeLists.txt old mode 100755 new mode 100644 index 4b569ec25f12..431f68e6608b --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -298,6 +298,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() + # Only build AllSpark kernels if we are building for at least some compatible archs. + cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}") + if (ALLSPARK_ARCHS) + set(ALLSPARK_SRCS + "csrc/quantization/gptq_allspark/allspark_repack.cu" + "csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu") + set_gencode_flags_for_srcs( + SRCS "${ALLSPARK_SRCS}" + CUDA_ARCHS "${ALLSPARK_ARCHS}") + list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}") + message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}") + else() + message(STATUS "Not building AllSpark kernels as no compatible archs found" + " in CUDA target architectures") + endif() + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). 
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index c22e66c0b0c9..8fd3a8d06f67 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -10,6 +10,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) @@ -18,12 +20,12 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, gptq_quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str, GPTQ_MARLIN_24_MAX_PARALLEL) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) + # AllSpark W8A16 quant + as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 and not act_order and is_k_full) + if as_supported_case: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + supported_arch = (sm_version >= 80 and sm_version < 90) + as_supported_case = as_supported_case and supported_arch + if supported_arch: + has_zp = False + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, + has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = \ + ops.allspark_repack_weight( + qw, s, zp, has_zp) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD + globals = { # Gen params "quant_type": quant_type, @@ -109,10 +132,21 @@ def bench_run(results: List[benchmark.Measurement], model: str, # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, + # AllSpark W8A16 params + "qw_reorder": qw_reorder if as_supported_case else None, + "s_reorder": s_reorder if as_supported_case else None, + "zp_reorder": zp_reorder if as_supported_case else None, + "sm_count": sm_count if as_supported_case else None, + "sm_version": sm_version if as_supported_case else None, + "CUBLAS_M_THRESHOLD": + CUBLAS_M_THRESHOLD if as_supported_case else None, + "weight_name_pattern": + f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight', # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, + "allspark_w8a16_gemm": ops.allspark_w8a16_gemm, } min_run_time = 1 @@ -172,6 +206,17 @@ def bench_run(results: List[benchmark.Measurement], model: str, description="gptq_marlin_repack", ).blocked_autorange(min_run_time=min_run_time)) + if as_supported_case: + results.append( + benchmark.Timer( + stmt= + 
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True, weight_name_pattern)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="allspark_w8a16_gemm_fp32", + ).blocked_autorange(min_run_time=min_run_time)) + def main(args): print("Benchmarking models:") diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu new file mode 100644 index 000000000000..eccc755f90b5 --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -0,0 +1,1021 @@ +#include "allspark_utils.cuh" +#include +#include "core/registration.h" +#include + +std::map as_g_output_map; // cache for 1 layer +at::Tensor as_g_workspace; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +torch::Tensor allspark_w8a16_gemm( + torch::Tensor const& a, torch::Tensor const& b_qweight, + torch::Tensor const& b_scales, c10::optional const& b_qzeros, + int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder, + std::string const& weight_name_pattern) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "allspark_w8a16_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else +namespace allspark { +/* + * GemmTile manage data movement from Global Memory to Shared Memory + * requiring N % 8 == 0, K % 16 == 0 by loading uint + * BN is obtained by padding the original N to a multiple of 32 + * weight B is rearranged as N32K16 order, + * i.e. a initial data block of size 32(n)x16(k) is reordered as n8k4n4k4, + * in order to put data loaded by the same thread of 32x16 data block together + * continuously (see + * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type) + */ +template +struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { + // element num loaded by a LDG inst. + static constexpr int LDG_ELEMENT_CNT_A = 8; + static constexpr int LDG_ELEMENT_CNT_B = 16; + static constexpr int WARP_SIZE = 32; + static constexpr int M_SIZE_ONE_LOAD = (BLOCK * LDG_ELEMENT_CNT_A) / 32; + static constexpr int N_SIZE_ONE_LOAD = (BLOCK * LDG_ELEMENT_CNT_B) / 32; + + __device__ GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK( + const SM8x_GEMM_W8A16_Splitk_Params& k_params, + const uint32_t& A_smem_addr, const uint32_t& BQ_smem_addr, + const uint32_t& A_stage_stride, const uint32_t& BQ_stage_stride) + : params(k_params), + A_smem_base_addr(A_smem_addr), + BQ_smem_base_addr(BQ_smem_addr), + A_smem_stage_stride(A_stage_stride), + BQ_smem_stage_stride(BQ_stage_stride) { + this_block_A_base_ptr = params.A_ptr + blockIdx.x * Mtile * params.K + + blockIdx.z * params.SplitK; + // here B is rearranged as N32K16 order, i.e. 
4 continuous N-direction + // 8(N)x16(K) size data blocks are packed together + this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K + + blockIdx.z * params.SplitK * 4; + + const int lane_id = threadIdx.x % WARP_SIZE; + + // For matrix A, a block load/store Mtile(row) x 32(col) elements in + // multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter + const int Aldg_row_base_idx = threadIdx.x / 4; + Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A; + const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx; + + // For matrix B, a block load/store elements of (Ntile / 4) row x 128 col + // elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row) + // * 128(col) per iter + Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B; + const int Bldg_row_base_idx = threadIdx.x / 8; + const int Bldg_base_offset = + Bldg_row_base_idx * params.K * 4 + Bldg_col_idx; + + this_block_A_base_ptr += Aldg_base_offset; + this_block_B_base_ptr += Bldg_base_offset; + + const int sts_a_base_offset = + (threadIdx.x / 4) * 32 + + ((lane_id % 4) ^ ((lane_id / 4) % 4) ^ ((lane_id / 4) / 4)) * + LDG_ELEMENT_CNT_A; + const int sts_bq_base_offset = + Bldg_row_base_idx * 32 * 4 + + ((threadIdx.x % 8) ^ (((threadIdx.x / 8) % 2) * 4)) * LDG_ELEMENT_CNT_B; + + A_smem_base_addr += sts_a_base_offset * sizeof(FType); + BQ_smem_base_addr += sts_bq_base_offset * sizeof(uint8_t); + + A_ldg_guard = 0; + B_ldg_guard = 0; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD; + if (m_idx < params.M) { + A_ldg_guard |= (1u << i); + } + } + + const int N_padded = (params.N + 31) / 32 * 32; + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 + + i * N_SIZE_ONE_LOAD; + if (n_idx < N_padded) { + B_ldg_guard |= (1u << i); + } + } + } + + __device__ void ldgsts_first_ktiles(const int& first_k_tile, + const int& k_tiles) { + // load first k_tile + // load A + const int A_src_size = Aldg_col_idx < first_k_tile ? 16 : 0; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + cp_async<16>( + A_smem_base_addr + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, A_src_size, + (A_ldg_guard & (1u << i)) != 0); + } + + // load B + const int B_src_size = (Bldg_col_idx / 4) < first_k_tile ? 
16 : 0; + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + cp_async<16>( + BQ_smem_base_addr + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, B_src_size, + (B_ldg_guard & (1u << i)) != 0); + } + + cp_async_commit_group(); + this_block_A_base_ptr += first_k_tile; + this_block_B_base_ptr += (first_k_tile * 4); + + // load second to (N-stage - 1) k_tiles + for (int stage_idx = 1; stage_idx < NStage - 1; ++stage_idx) { + if (stage_idx < k_tiles) { + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; + ++i) { + cp_async<16>(A_smem_base_addr + stage_idx * A_smem_stage_stride + + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, + 16, (A_ldg_guard & (1u << i)) != 0); + } + + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; + ++i) { + cp_async<16>(BQ_smem_base_addr + stage_idx * BQ_smem_stage_stride + + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, + 16, (B_ldg_guard & (1u << i)) != 0); + } + + this_block_A_base_ptr += 32; + this_block_B_base_ptr += (32 * 4); + } + cp_async_commit_group(); + } + } + + __device__ void ldgsts(const int& sts_stage_idx) { + const int a_stage_offset = sts_stage_idx * A_smem_stage_stride; + const int bq_stage_offset = sts_stage_idx * BQ_smem_stage_stride; + #pragma unroll + for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) { + cp_async<16>(A_smem_base_addr + a_stage_offset + + (i * M_SIZE_ONE_LOAD * 32) * sizeof(FType), + this_block_A_base_ptr + i * M_SIZE_ONE_LOAD * params.K, 16, + (A_ldg_guard & (1u << i)) != 0); + } + + #pragma unroll + for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) { + cp_async<16>(BQ_smem_base_addr + bq_stage_offset + + (i * N_SIZE_ONE_LOAD * 32) * sizeof(uint8_t), + this_block_B_base_ptr + i * N_SIZE_ONE_LOAD * params.K, 16, + (B_ldg_guard & (1u << i)) != 0); + } + + cp_async_commit_group(); + this_block_A_base_ptr += 32; + this_block_B_base_ptr += (32 * 4); + } + + const FType* this_block_A_base_ptr = nullptr; + const QType* this_block_B_base_ptr = nullptr; + + int Aldg_col_idx; + int Bldg_col_idx; + + uint32_t A_ldg_guard; + uint32_t B_ldg_guard; + + uint32_t A_smem_base_addr, BQ_smem_base_addr; + const uint32_t A_smem_stage_stride, BQ_smem_stage_stride; + + const SM8x_GEMM_W8A16_Splitk_Params& params; +}; + +/* + * requiring N % 8 == 0 + */ +template +struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { + static constexpr int WARP_SIZE = 32; + static constexpr int WARP_CNT = BLOCK / WARP_SIZE; + static constexpr int WARP_NTILE = Ntile / WARP_CNT; + static constexpr int WARP_NITER = WARP_NTILE / 8; // hmma16816 + static_assert(WARP_NTILE == 32 or WARP_NTILE == 64, + "now only support WARP_NTILE = 32 or 64!"); + + __device__ ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK( + const SM8x_GEMM_W8A16_Splitk_Params& k_params, + const uint32_t& A_smem_addr, const uint32_t& BQ_smem_addr, + const uint32_t& A_stage_stride, const uint32_t& BQ_stage_stride) + : params(k_params), + A_smem_base_addr(A_smem_addr), + BQ_smem_base_addr(BQ_smem_addr), + A_smem_stage_stride(A_stage_stride), + BQ_smem_stage_stride(BQ_stage_stride) { + warp_id = threadIdx.x / WARP_SIZE; + lane_id = threadIdx.x % WARP_SIZE; + + load_a_base_offset[0] = + (lane_id % 16) * 32 + + ((lane_id / 16) ^ (lane_id % 4) ^ ((lane_id / 4) 
% 2)) * 8; + load_a_base_offset[1] = + (lane_id % 16) * 32 + + ((lane_id / 16 + 2) ^ (lane_id % 4) ^ ((lane_id / 4) % 2)) * 8; + + load_b_base_offset[0] = + (lane_id / 4 + warp_id * (WARP_NTILE / 4)) * 32 * 4 + + (lane_id % 4) * 16 + ((lane_id / 4) % 2) * 16 * 4; + load_b_base_offset[1] = + (lane_id / 4 + warp_id * (WARP_NTILE / 4)) * 32 * 4 + + (lane_id % 4) * 16 + (((lane_id / 4) % 2) ^ 1) * 16 * 4; + + sts_c_base_offset = warp_id * Mtile * WARP_NTILE + + (lane_id / 4) * WARP_NTILE + (lane_id % 4) * 2; + + if (EnableFuse) { + this_block_C_base_ptr = + params.C_ptr + blockIdx.x * Mtile * params.N + blockIdx.y * Ntile; + } else { + this_block_C_base_ptr = + params.C_split_ptr + blockIdx.z * params.M * params.N + + blockIdx.x * Mtile * params.N + blockIdx.y * Ntile; + } + int store_thds_in_row = WARP_NTILE / 8; + store_c_row_base_idx = lane_id / store_thds_in_row; + store_c_col_idx = warp_id * WARP_NTILE + (lane_id % store_thds_in_row) * 8; + store_c_base_offset = store_c_row_base_idx * params.N + store_c_col_idx; + + #pragma unroll + for (int i = 0; i < Mtile / 16; ++i) { + #pragma unroll + for (int j = 0; j < WARP_NITER; ++j) { + #pragma unroll + for (int k = 0; k < 4; ++k) { + C_frag[i][j][k] = 0.f; + } + } + } + params_n_idx = + blockIdx.y * Ntile + warp_id * WARP_NTILE + (lane_id / 4) * 4; + } + + __device__ void lds(const int& smem_stage_idx, const int& reg_buf_idx, + const int& k_phase_idx) { + uint32_t A_smem_addr = + A_smem_base_addr + A_smem_stage_stride * smem_stage_idx; + uint32_t B_smem_addr = + BQ_smem_base_addr + BQ_smem_stage_stride * smem_stage_idx; + + #pragma unroll + for (int i = 0; i < Mtile / 16; ++i) { + ldsm_4(A_frag[reg_buf_idx][i][0], A_frag[reg_buf_idx][i][1], + A_frag[reg_buf_idx][i][2], A_frag[reg_buf_idx][i][3], + A_smem_addr + (load_a_base_offset[k_phase_idx] + i * 16 * 32) * + sizeof(FType)); + } + #pragma unroll + for (int i = 0; i < WARP_NTILE / 32; ++i) { + lds128(BQ_frag[reg_buf_idx][4 * i + 0], BQ_frag[reg_buf_idx][4 * i + 1], + BQ_frag[reg_buf_idx][4 * i + 2], BQ_frag[reg_buf_idx][4 * i + 3], + B_smem_addr + (load_b_base_offset[k_phase_idx] + i * 32 * 32) * + sizeof(uint8_t)); + } + + // dequant B + #pragma unroll + for (int i = 0; i < WARP_NITER / 2; ++i) { + cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i], + BF_frag[reg_buf_idx][2 * i]); + if (has_zp) { + BF_frag[reg_buf_idx][2 * i][0] = + __hsub2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_zero[i].x)); + BF_frag[reg_buf_idx][2 * i][1] = + __hsub2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_zero[i].x)); + } + + BF_frag[reg_buf_idx][2 * i][0] = + __hmul2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x)); + BF_frag[reg_buf_idx][2 * i][1] = + __hmul2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x)); + + cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i + 1], + BF_frag[reg_buf_idx][2 * i + 1]); + if (has_zp) { + BF_frag[reg_buf_idx][2 * i + 1][0] = + __hsub2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_zero[i].y)); + BF_frag[reg_buf_idx][2 * i + 1][1] = + __hsub2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_zero[i].y)); + } + + BF_frag[reg_buf_idx][2 * i + 1][0] = + __hmul2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_scale[i].y)); + BF_frag[reg_buf_idx][2 * i + 1][1] = + __hmul2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_scale[i].y)); + } + } + + __device__ void ldg_params() { + const int N_padded = (params.N + 31) / 32 * 32; + // load B scale and zero_point + #pragma unroll + for (int i = 0; i < WARP_NTILE / 32; ++i) { + ldg64_ca(B_scale[2 * i + 0], B_scale[2 * i + 1], + 
params.B_scale_ptr + params_n_idx + i * 32, + (params_n_idx + i * 32) < N_padded); + if (has_zp) { + ldg64_ca(B_zero[2 * i + 0], B_zero[2 * i + 1], + params.B_zero_ptr + params_n_idx + i * 32, + (params_n_idx + i * 32) < N_padded); + } + } + } + + __device__ void mma(const int& reg_buf_idx) { + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + hmma16816_f32( + C_frag[m_idx][n_idx], A_frag[reg_buf_idx][m_idx], + reinterpret_cast(BF_frag[reg_buf_idx][n_idx])); + } + } + } + + __device__ void fused_splitk_reduce() { + // need splitk-reduce if enable splitk + if (gridDim.z > 1) { + int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y; + // Wait for all previous blocks in the splitk direction to accumulate the + // results into C_tmp + if (threadIdx.x == 0) { + uint32_t* red_count_ptr = params.red_count_ptr + blk_red_idx; + uint32_t count; + do { + // make sure the ld.cg inside the do-wile loop + __threadfence_block(); + asm volatile("ld.global.cg.b32 %0, [%1];" + : "=r"(count) + : "l"(red_count_ptr)); + } while (count != blockIdx.z); + } + __syncthreads(); + + int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4; + if (blockIdx.z != 0) { + // expecting that temporary register here reuses the previous A&B frag + // register + float temp_frag[Mtile / 16][WARP_NITER][4]; + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + int offset = + C_tmp_base_offset + (m_idx * WARP_NITER + n_idx) * BLOCK * 4; + *reinterpret_cast(temp_frag[m_idx][n_idx]) = + *reinterpret_cast(params.C_tmp_ptr + offset); + } + } + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + #pragma unroll + for (int idx = 0; idx < 4; ++idx) { + C_frag[m_idx][n_idx][idx] += temp_frag[m_idx][n_idx][idx]; + } + } + } + } + + // first splitk - 1 blocks need to write partial results into C_tmp + if (blockIdx.z != gridDim.z - 1) { + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + int offset = + C_tmp_base_offset + (m_idx * WARP_NITER + n_idx) * BLOCK * 4; + asm volatile( + "{st.global.cg.v4.b32 [%0], {%1, %2, %3, %4};}\n" + : + : "l"(params.C_tmp_ptr + offset), "f"(C_frag[m_idx][n_idx][0]), + "f"(C_frag[m_idx][n_idx][1]), "f"(C_frag[m_idx][n_idx][2]), + "f"(C_frag[m_idx][n_idx][3])); + } + } + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0) { + uint32_t* red_count_ptr = params.red_count_ptr + blk_red_idx; + atomicInc(red_count_ptr, gridDim.z); + } + } + } + } + + __device__ void stg(char* smem) { + if (EnableFuse) { + if (blockIdx.z != gridDim.z - 1) return; + } + uint32_t* C_sts_ptr = + reinterpret_cast(smem + sts_c_base_offset * sizeof(FType)); + // C_tile sts + #pragma unroll + for (int m_idx = 0; m_idx < Mtile / 16; ++m_idx) { + #pragma unroll + for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { + #pragma unroll + for (int k_idx = 0; k_idx < 2; ++k_idx) { + FType low16 = static_cast(C_frag[m_idx][n_idx][k_idx * 2]); + FType high16 = + static_cast(C_frag[m_idx][n_idx][k_idx * 2 + 1]); + uint32_t tmp = (reinterpret_cast(low16) & 0xffff) | + (reinterpret_cast(high16) << 16); + int sts_offset = + m_idx * 16 * (WARP_NTILE / 2) + + (((lane_id / (32 / WARP_NITER)) + n_idx) % WARP_NITER) * (8 / 2) + + k_idx * 8 * (WARP_NTILE / 2); + C_sts_ptr[sts_offset] = tmp; + 
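+          // Descriptive note (added): each pair of fp32 accumulators is
+          // narrowed to FType and packed into a single 32-bit word before the
+          // shared-memory store; the lane-dependent rotation of n_idx inside
+          // sts_offset spreads a warp's writes across smem columns, a common
+          // swizzle to limit bank conflicts.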
} + } + } + + __syncthreads(); + + FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset; + // C_tile lds and stg + int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile; + bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N; + if (WARP_NTILE == 32) { + int lds_c_base_offset = warp_id * Mtile * WARP_NTILE + + (lane_id / 4) * WARP_NTILE + + ((lane_id % 4 + lane_id / 8) % 4) * 8; + uint4* C_lds_ptr = + reinterpret_cast(smem + lds_c_base_offset * sizeof(FType)); + #pragma unroll + for (int i = 0; i < (Mtile / 16) * (WARP_NITER / 2); ++i) { + uint4 stg_reg = C_lds_ptr[i * 8 * 4]; + stg128(stg_reg.x, stg_reg.y, stg_reg.z, stg_reg.w, + C_base_ptr + i * 8 * params.N, + (m_base_idx + i * 8) < params.M && n_guard); + } + } else if (WARP_NTILE == 64) { + int lds_c_base_offset = + warp_id * Mtile * WARP_NTILE + (lane_id / 8) * WARP_NTILE; + #pragma unroll + for (int i = 0; i < (Mtile / 16) * (WARP_NITER / 2); ++i) { + int lds_c_offset = lds_c_base_offset + i * 4 * WARP_NTILE + + ((lane_id % 8 + lane_id / 8 + (i % 2) * 4) % 8) * 8; + uint4 stg_reg = + *reinterpret_cast(smem + lds_c_offset * sizeof(FType)); + stg128(stg_reg.x, stg_reg.y, stg_reg.z, stg_reg.w, + C_base_ptr + i * 4 * params.N, + (m_base_idx + i * 4) < params.M && n_guard); + } + } + } + + const SM8x_GEMM_W8A16_Splitk_Params& params; + + int load_a_base_offset[2]; + int load_b_base_offset[2]; + int sts_c_base_offset; + + int store_c_base_offset; + + int store_c_row_base_idx, store_c_col_idx; + FType* this_block_C_base_ptr = nullptr; + + int params_n_idx; + const uint32_t A_smem_base_addr, BQ_smem_base_addr; + const uint32_t A_smem_stage_stride, BQ_smem_stage_stride; + + int lane_id; + int warp_id; + // first 2 denotes double buffer, second dim denotes M direction + uint32_t A_frag[2][Mtile / 16][4]; + + typename HalfType::T2 B_scale[WARP_NITER / 2]; + typename HalfType::T2 B_zero[WARP_NITER / 2]; + uint32_t BQ_frag[2][WARP_NITER]; + // first 2 denotes double buffer, second dim denotes N direction, last 2 + // denotes K direction + typename HalfType::T2 BF_frag[2][WARP_NITER][2]; + // first dim denotes M direction, second dim denotes N direction + float C_frag[Mtile / 16][WARP_NITER][4]; +}; + +/* + * @brief W8A16 Perchannel Quantization GEMM, + * requires N % 8 == 0, K % 16 == 0 + * accumulator precision: FP32 + * @tparam FType: DataType for A, B_scale, B_zero, and C, supports half or + * nv_bfloat16 + * @tparam QType: DataType for B, support uint8(bias128) + * @tparam Mtile: M-dimensional size of the gemm block tile, supports 16, 32, + * 48 or 64 + * @tparam Ntile: N-dimensional size of the gemm block tile, supports 128 or + * 256 + * @tparam NStage: Num of stages for async copy + * @tparam BLOCK: BLOCK size + * @tparam EnableFuse: If true, use fused splitk-reduce, otherwise use + * non-fused splitk-reduce + * @tparam has_zp: whether to use zero_point + * + * @fparam params struct consists of following parameters: + * @param A_ptr: Matrix A value ptr, A = (M, K) + * @param B_ptr: Matrix B value ptr, B = (N32_align, K) (N32K16 special + * format), N32_align = (N + 32 - 1) / 32 * 32 + * @param B_scale_ptr: B_scale value ptr, B_scale = (N32_align,) (N32K16 + * special format) + * @param B_zero_ptr: B_zero value ptr, B_zero = (N32_align,) (N32K16 + * special format) + * @param C_ptr: Matrix C value ptr, C = (M, N) + * @param M: dimnesion m + * @param N: dimnesion n + * @param K: dimnesion k + * @param SplitK: split size along K-dimension + * @param C_split_ptr: Matrix C_split value ptr, used only in non-fused + * 
splitk-reduce + * @param C_tmp_ptr: Matrix C_tmp value ptr, used only in fused + * splitk-reduce + * @param red_count_ptr: 1-D red_count value ptr, used only in fused + * splitk-reduce + */ +template +__global__ void __launch_bounds__(BLOCK) + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel( + const SM8x_GEMM_W8A16_Splitk_Params params) { + // A smem size = 64 * 32 * 2B/elem * 4(stage) = 16KB + // B smem size = 128 * 32 * 1B/elem * 4(stage) = 16KB + constexpr int smem_size_one_stage = Mtile * 32 * 2 + Ntile * 32; + __shared__ char smem[NStage * smem_size_one_stage]; + char* A_smem = smem; + char* BQ_smem = smem + Mtile * 32 * 2 * NStage; + + uint32_t A_smem_addr = smem_u32addr(A_smem); + uint32_t BQ_smem_addr = smem_u32addr(BQ_smem); + uint32_t A_smem_stage_stride = Mtile * 32 * 2; + uint32_t BQ_smem_stage_stride = Ntile * 32; + + // initialize the data move process from GM to SMEM for this block + GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK< + FType, QType, Mtile, Ntile, NStage, BLOCK> + gmem_tile(params, A_smem_addr, BQ_smem_addr, A_smem_stage_stride, + BQ_smem_stage_stride); + + int sts_stage_idx = 0; + int lds_stage_idx = 0; + + int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K + ? params.SplitK + : params.K - blockIdx.z * params.SplitK; + int k_tiles = (tb_k_slice + 31) / 32; + int first_k_tile = tb_k_slice - (k_tiles - 1) * 32; + + // load first three tiles to shared memory + gmem_tile.ldgsts_first_ktiles(first_k_tile, k_tiles); + sts_stage_idx += (NStage - 2); + ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK< + FType, QType, Mtile, Ntile, BLOCK, EnableFuse, has_zp> + compute_tile(params, A_smem_addr, BQ_smem_addr, A_smem_stage_stride, + BQ_smem_stage_stride); + compute_tile.ldg_params(); + cp_asyc_wait_group(); + __syncthreads(); + + compute_tile.lds(lds_stage_idx, 0, 0); + int reg_buf_idx = 1; + + // main loop + for (; k_tiles > NStage - 1; --k_tiles) { + // load next A&B tile + sts_stage_idx = sts_stage_idx < NStage - 1 ? sts_stage_idx + 1 : 0; + gmem_tile.ldgsts(sts_stage_idx); + + #pragma unroll + for (int k_phase_idx = 0; k_phase_idx < 2; k_phase_idx++) { + // dequantize next B tile + if (k_phase_idx == 1) { + cp_asyc_wait_group(); + __syncthreads(); + lds_stage_idx = lds_stage_idx < NStage - 1 ? lds_stage_idx + 1 : 0; + } + + compute_tile.lds(lds_stage_idx, reg_buf_idx, (k_phase_idx + 1) % 2); + + compute_tile.mma(reg_buf_idx ^ 1); + reg_buf_idx ^= 1; + } + } + + // last NStage-1 tiles + for (; k_tiles > 0; --k_tiles) { + cp_async_commit_group(); + #pragma unroll + for (int k_phase_idx = 0; k_phase_idx < 2; k_phase_idx++) { + // dequantize next B tile + if (k_phase_idx == 1) { + cp_asyc_wait_group(); + __syncthreads(); + lds_stage_idx = lds_stage_idx < NStage - 1 ? 
lds_stage_idx + 1 : 0; + } + + compute_tile.lds(lds_stage_idx, reg_buf_idx, (k_phase_idx + 1) % 2); + + compute_tile.mma(reg_buf_idx ^ 1); + reg_buf_idx ^= 1; + } + } + + if (EnableFuse) { + compute_tile.fused_splitk_reduce(); + } + compute_tile.stg(smem); +} + + #define __CALL_IF(MTILE, NTILE, NUM_THREADS, ENABLE_FUSE, HAS_ZP) \ + else if (Mtile == MTILE && Ntile == NTILE && BLOCK == NUM_THREADS && \ + enable_fuse == ENABLE_FUSE && has_zp == HAS_ZP) { \ + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_hmma16816_multistage_AN_BTN32K16_CN_splitk_kernel< \ + FType, QType, MTILE, NTILE, 4, NUM_THREADS, ENABLE_FUSE, HAS_ZP> \ + <<>>(params); \ + } + +template +void ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk( + const FType* A, const QType* B, const FType* B_scale, const FType* B_zero, + FType* C, const int M, const int N, const int K, void* workspace, + const int sm_version, const BlockTileSplitkParams& fused_gemm_params, + cudaStream_t stream) { + int Mtile = fused_gemm_params.Mtile; + int grid_x = (M + Mtile - 1) / Mtile; + int Ntile = fused_gemm_params.Ntile; + int grid_y = (N + Ntile - 1) / Ntile; + int SplitK = fused_gemm_params.SplitK; + int grid_z = (K + SplitK - 1) / SplitK; + + int BLOCK = (Ntile == 256) ? 256 : 128; + + dim3 grid(grid_x, grid_y, grid_z); + dim3 block(BLOCK); + + bool enable_fuse = fused_gemm_params.EnableFuse; + bool has_zp = B_zero != nullptr; + if (enable_fuse) { + float* C_tmp = reinterpret_cast(workspace); + uint32_t* red_count = reinterpret_cast( + (char*)workspace + grid_x * Mtile * grid_y * Ntile * sizeof(float)); + CHECK_CUDA(cudaMemsetAsync(red_count, 0, grid_x * grid_y * sizeof(uint32_t), + stream)); + SM8x_GEMM_W8A16_Splitk_Params params{ + A, B, B_scale, B_zero, C, M, N, + K, SplitK, 0, -1, nullptr, C_tmp, red_count}; + + if (false) { + } + // Select the template parameters for kernel launch + // according to the above settings. Tuning is not supported. + __CALL_IF(16, 256, 256, true, false) + __CALL_IF(32, 256, 256, true, false) + __CALL_IF(48, 256, 256, true, false) + __CALL_IF(64, 128, 128, true, false) + __CALL_IF(64, 256, 256, true, false) + __CALL_IF(16, 256, 256, true, true) + __CALL_IF(32, 256, 256, true, true) + __CALL_IF(48, 256, 256, true, true) + __CALL_IF(64, 128, 128, true, true) + __CALL_IF(64, 256, 256, true, true) + } else { + FType* C_split = reinterpret_cast(workspace); + SM8x_GEMM_W8A16_Splitk_Params params{ + A, B, B_scale, B_zero, C, M, N, + K, SplitK, 0, -1, C_split, nullptr, nullptr}; + + if (false) { + } + // Select the template parameters for kernel launch + // according to the above settings. Tuning is not supported. + __CALL_IF(16, 256, 256, false, false) + __CALL_IF(32, 256, 256, false, false) + __CALL_IF(48, 256, 256, false, false) + __CALL_IF(64, 128, 128, false, false) + __CALL_IF(64, 256, 256, false, false) + __CALL_IF(16, 256, 256, false, true) + __CALL_IF(32, 256, 256, false, true) + __CALL_IF(48, 256, 256, false, true) + __CALL_IF(64, 128, 128, false, true) + __CALL_IF(64, 256, 256, false, true) + + // SplitK reduce + f16_gemm_splitk_reduce(C_split, C, M, N, grid_z, stream); + } +} + +size_t allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size( + int m, int n, int k, int sm_count, + BlockTileSplitkParams& fused_gemm_params) { + // Determine the block tile and splitk strategy + int m16_times = (m + 16 - 1) / 16; + int Mtile = m16_times <= 4 ? 
m16_times * 16 : 64; + int grid_x = (m + Mtile - 1) / Mtile; + int Ntile = + (float(grid_x * ((n + 127) / 128)) / sm_count > 10) || (Mtile < 64) ? 256 + : 128; + int grid_y = (n + Ntile - 1) / Ntile; + int grid_z; + + // split-k + const float SPLIT_THRESHOLD = 0.8; + int n_slice; + for (n_slice = 1; n_slice < k / 256; ++n_slice) { + int n_block = grid_x * grid_y * n_slice; + if (n_block >= sm_count * SPLIT_THRESHOLD && + (n_block % sm_count == 0 || n_block % sm_count >= sm_count * 0.5)) { + break; + } + } + + int k_slice = + (k / n_slice) % 32 == 0 ? k / n_slice : k / n_slice / 32 * 32 + 32; + grid_z = (k + k_slice - 1) / k_slice; + bool enable_fuse = float(grid_x * grid_y) / sm_count >= 0.5 ? 1 : 0; + + size_t ws_size; + if (enable_fuse) { + ws_size = grid_x * Mtile * grid_y * Ntile * sizeof(float) // For C_tmp + + grid_x * grid_y * sizeof(uint32_t); // For red_count + } else { + ws_size = grid_z * m * n * sizeof(__half); + } + + fused_gemm_params.Mtile = Mtile; + fused_gemm_params.Ntile = Ntile; + fused_gemm_params.SplitK = k_slice; + fused_gemm_params.EnableFuse = enable_fuse; + return ws_size; +} + +// restore from N32K16 order to original N-major order +// K % 16 == 0, N % 8 == 0 +// each block process 64(k) * 32(n) result elements +template +__global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel( + const QT* qdata, const FT* scales, const FT* zeros, FT* fdata, + const int N_32align, const int N, const int K) { + __shared__ FT smem[64 * 32]; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + const int src_row_idx = blockIdx.x * 8 + lane_id / 4; + const int src_col_idx = + blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16; + const int src_offset = src_row_idx * K * 4 + src_col_idx; + int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4; + + QT qval_reg[16]; + const QT* pdata = qdata + src_offset; + if (src_col_idx < (K * 4)) { + *(reinterpret_cast(qval_reg)) = + *(reinterpret_cast(qdata + src_offset)); + } + FT scale_reg[4]; + *(reinterpret_cast(scale_reg)) = + *(reinterpret_cast(scales + params_nidx)); + FT zero_reg[4] = {0}; + if (zeros != nullptr) { + *(reinterpret_cast(zero_reg)) = + *(reinterpret_cast(zeros + params_nidx)); + } + FT fval_reg[16]; + + const int sts_base_offset = + (warp_id * 16 + (lane_id % 4) * 2) * 32 + lane_id / 4; + #pragma unroll + for (int ni = 0; ni < 4; ++ni) { + cvt_8bx4_to_16bx4_bias128( + *reinterpret_cast(&qval_reg[ni * 4]), + reinterpret_cast::T2*>(&(fval_reg[ni * 4]))); + #pragma unroll + for (int ki = 0; ki < 4; ++ki) { + fval_reg[ni * 4 + ki] = + (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni]; + int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 + + ((ni + lane_id % 4) % 4) * 8; + smem[sts_offset] = fval_reg[ni * 4 + ki]; + } + } + __syncthreads(); + + const int lds_base_offset = + (threadIdx.x / 4) * 32 + ((threadIdx.x % 4 + threadIdx.x / 8) % 4) * 8; + #pragma unroll + for (int i = 0; i < 2; ++i) { + *reinterpret_cast(fval_reg + i * 8) = + *reinterpret_cast(smem + lds_base_offset + i * 32 * 32); + } + + const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4; + const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8; + #pragma unroll + for (int i = 0; i < 2; ++i) { + int dst_row_kidx = dst_row_base_kidx + i * 32; + int dst_offset = dst_row_kidx * N + dst_col_nidx; + if (dst_row_kidx < K && dst_col_nidx < N) { + *reinterpret_cast(fdata + dst_offset) = + *reinterpret_cast(fval_reg + i * 8); + } + } +} + +template +void restore_N32_K16_dequantize_rhs_w8a16(const QT* 
qdata, const FT* scales, + const FT* zeros, FT* fdata, + const int N_32align, const int N, + const int K, const int GroupSize, + cudaStream_t stream) { + TORCH_CHECK(N % 8 == 0 && K % 16 == 0 && N_32align % 32 == 0, + "Unsupported shape"); + if (GroupSize == -1) { + const int BLOCK = 128; + dim3 grid(N_32align / 32, ((K / 16) + 3) / 4); + restore_N32_K16_dequantize_rhs_w8a16_perc_kernel + <<>>(qdata, scales, zeros, fdata, N_32align, N, + K); + } + // TODO: Support SubChannel + else { + TORCH_CHECK(false, "Now only support PerChannel"); + } +} + +template +void w8a16_gemm_dq_cublas(const FT* in, const QT* rhs_qdata_ptr, + const FT* rhs_scales_ptr, const FT* rhs_zeros_ptr, + FT* out, void* workspace, const int M, + const int N_32align, const int N, const int K, + const int group_size, cudaStream_t stream, + cublasHandle_t handle) { + static_assert( + std::is_same::value || std::is_same::value, + "only float16 and bfloat16 is supported"); + // Dequant + FT* rhs_fdata_ptr = static_cast(workspace); + restore_N32_K16_dequantize_rhs_w8a16(rhs_qdata_ptr, rhs_scales_ptr, + rhs_zeros_ptr, rhs_fdata_ptr, N_32align, + N, K, group_size, stream); + // cuBLAS GEMM + int lda = K; + int ldb = N; + int ldc = N; + const float alpha = 1.0f; + const float beta = 0.0f; + cudaDataType_t cuda_type; + if (std::is_same::value) { + cuda_type = CUDA_R_16F; + } else { + cuda_type = CUDA_R_16BF; + } + CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, + rhs_fdata_ptr, cuda_type, ldb, in, cuda_type, lda, + &beta, out, cuda_type, ldc, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +template +void allspark_qgemm_w8a16_perc_ampere( + const FType* A, const QType* B, const FType* B_scale, const FType* B_zero, + FType* C, const int M, const int N_32align, const int N, const int K, + void* workspace, const BlockTileSplitkParams& fused_gemm_params, + const int group_size, int CUBLAS_M_THRESHOLD, const int sm_version, + cudaStream_t stream, cublasHandle_t handle) { + if (M > CUBLAS_M_THRESHOLD) { + w8a16_gemm_dq_cublas(A, B, B_scale, B_zero, C, workspace, M, + N_32align, N, K, group_size, stream, + handle); + } else { + ampere_hgemm_W8A16_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32K16_CN_splitk< + FType, QType>(A, B, B_scale, B_zero, C, M, N, K, workspace, sm_version, + fused_gemm_params, stream); + } +} + +} // namespace allspark + +torch::Tensor allspark_w8a16_gemm( + torch::Tensor const& a, torch::Tensor const& b_qweight, + torch::Tensor const& b_scales, c10::optional const& b_qzeros, + int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder, + std::string const& weight_name_pattern) { + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + if (has_zp) { + TORCH_CHECK(b_qzeros.value().device().is_cuda(), "b_qzeros is not on GPU"); + TORCH_CHECK(b_qzeros.value().is_contiguous(), "b_qzeros is not contiguous"); + } + + int m = a.size(0); + int n_32align = (n + 32 - 1) / 32 * 32; + int k = a.size(1); + + // Verify shape + TORCH_CHECK(b_qweight.size(0) == n_32align, + "Shape mismatch: b_qweight.size(0) = ", b_qweight.size(0), + ", 
n_32align = ", n_32align); + TORCH_CHECK(b_qweight.size(1) == k, + "Shape mismatch: b_qweight.size(1) = ", b_qweight.size(1), + ", k = ", k); + + TORCH_CHECK(group_size == -1, "Currently only supports group_size = -1"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + const void* a_ptr = reinterpret_cast(a.data_ptr()); + const uint8_t* b_ptr = reinterpret_cast(b_qweight.data_ptr()); + const void* b_scale_ptr = reinterpret_cast(b_scales.data_ptr()); + const void* b_zero_ptr = nullptr; + if (b_qzeros.has_value()) { + b_zero_ptr = reinterpret_cast(b_qzeros.value().data_ptr()); + } + + auto c_options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + if (as_g_output_map.count(weight_name_pattern) == 0 or + as_g_output_map.at(weight_name_pattern).numel() < m * n) { + as_g_output_map[weight_name_pattern] = torch::empty({m, n}, c_options); + } + torch::Tensor tensor_to_reuse = as_g_output_map[weight_name_pattern]; + std::vector new_shape = {m, n}; + int64_t dim1_step = tensor_to_reuse.stride(1); + int64_t dim0_step = dim1_step * n; + std::vector new_stride = {dim0_step, dim1_step}; + torch::Tensor c = + tensor_to_reuse.as_strided(new_shape, new_stride).to(a.dtype()); + void* c_ptr = reinterpret_cast(c.data_ptr()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + + allspark::BlockTileSplitkParams fused_gemm_params; + + size_t ws_size = 0; + if (m > CUBLAS_M_THRESHOLD) { + ws_size = k * n * 2; // sizeof(f16)==2 + } else { + ws_size = allspark::allspark_qgemm_w8a16_perc_n32k16_ampere_workspace_size( + m, n, k, sm_count, fused_gemm_params); + } + + auto ws_options = torch::TensorOptions().dtype(at::kChar).device(a.device()); + if (as_g_workspace.numel() < + ws_size) { // ws_options: kChar, so numel() is bytes + as_g_workspace = torch::empty({long(ws_size)}, ws_options); + } + void* ws = reinterpret_cast(as_g_workspace.data_ptr()); + + if (a.dtype() == at::ScalarType::Half) { + allspark::allspark_qgemm_w8a16_perc_ampere<__half, uint8_t>( + reinterpret_cast(a_ptr), b_ptr, + reinterpret_cast(b_scale_ptr), + reinterpret_cast(b_zero_ptr), + reinterpret_cast<__half*>(c_ptr), m, n_32align, n, k, ws, + fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, + handle); + } else if (a.dtype() == at::ScalarType::BFloat16) { + allspark::allspark_qgemm_w8a16_perc_ampere<__nv_bfloat16, uint8_t>( + reinterpret_cast(a_ptr), b_ptr, + reinterpret_cast(b_scale_ptr), + reinterpret_cast(b_zero_ptr), + reinterpret_cast<__nv_bfloat16*>(c_ptr), m, n_32align, n, k, ws, + fused_gemm_params, group_size, CUBLAS_M_THRESHOLD, sm_version, stream, + handle); + } + + return c; +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm); +} \ No newline at end of file diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/quantization/gptq_allspark/allspark_repack.cu new file mode 100644 index 000000000000..82929c94ad8b --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_repack.cu @@ -0,0 +1,163 @@ +#include "allspark_utils.cuh" +#include +#include "core/registration.h" + +namespace allspark { + +// Rearrange B to facilitate Ampere Tensor Core load data +// reorder B from (K, N) to (N_32align / 4, K * 4) +// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0 +template +__global__ void __launch_bounds__(128) + rearrange_kn_weight_as_n32k16_order_ldg16_kernel( + const uint8_t* B, const FType* B_scale, const FType* B_zero, + 
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result, + const int K, const int N, const int N_32align) { + const int lane_id = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + + if (blockIdx.x != gridDim.x - 1) { + // Load B + // per block process 64(k) * 128(n) B elements + // per warp process 16(k) * 128 B elements + const int src_row_base_idx = + blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2; + const int src_col_idx = + blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16; + uint8_t B_frag[4][16]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2); + int src_offset = src_row_idx * N + src_col_idx; + bool guard = src_row_idx < K && src_col_idx < N; + ldg128_cg_0(*reinterpret_cast(B_frag[i]), + *(reinterpret_cast(B_frag[i]) + 1), + *(reinterpret_cast(B_frag[i]) + 2), + *(reinterpret_cast(B_frag[i]) + 3), B + src_offset, + guard); + } + + // reorder B + uint8_t B_reorder_frag[8][8]; +#pragma unroll + for (int i = 0; i < 4; ++i) { +#pragma unroll + for (int j = 0; j < 16; ++j) { + int dst_i = j % 8; + int dst_j = i + (j / 8) * 4; + B_reorder_frag[dst_i][dst_j] = B_frag[i][j]; + } + } + + // Store B + const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8; + const int dst_col_idx = + blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8; + for (int i = 0; i < 8; ++i) { + int dst_row_idx = dst_row_base_idx + i; + int dst_offset = dst_row_idx * K * 4 + dst_col_idx; + bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4); + if (guard) { + *reinterpret_cast(B_result + dst_offset) = + *reinterpret_cast(B_reorder_frag[i]); + } + } + } else { + // Load B_scale and B_zero + FType b_scale_reg, b_zero_reg; + int src_offset = blockIdx.y * 128 + threadIdx.x; + ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N); + if (B_zero != nullptr) + ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N); + int dst_offset = + blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8; + if (dst_offset < N_32align) { + B_scale_result[dst_offset] = b_scale_reg; + if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg; + } + } +} + +template +void rearrange_kn_weight_as_n32k16_order_ldg16( + const uint8_t* B, const FType* B_scale, const FType* B_zero, + uint8_t* B_result, FType* B_scale_result, FType* B_zero_result, + const int64_t K, const int64_t N, const int64_t N_32align, + cudaStream_t stream) { + if (N % 16 != 0 || K % 16 != 0) { + std::cerr << "Now only support N and K is multiples of 16" << std::endl; + } + const int BLOCK = 128; + int grid_x = (K + 64 - 1) / 64 + 1; + int grid_y = (N + 128 - 1) / 128; + dim3 grid(grid_x, grid_y); + + rearrange_kn_weight_as_n32k16_order_ldg16_kernel + <<>>(B, B_scale, B_zero, B_result, B_scale_result, + B_zero_result, K, N, N_32align); +} +} // namespace allspark + +void rearrange_kn_weight_as_n32k16_order( + torch::Tensor const& b_qweight, torch::Tensor const& b_scales, + c10::optional const& b_zeros, bool has_zp, + torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, + c10::optional const& b_zeros_reorder, const int64_t K, + const int64_t N, const int64_t N_32align) { + // Verify device and strides + TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); + TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + 
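+  // Descriptive note (added): the *_reorder tensors are outputs. The repack
+  // kernel writes the N32K16-ordered weight and the realigned scales/zeros
+  // into them, so they must already be allocated on the GPU and contiguous,
+  // which the checks below enforce.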
TORCH_CHECK(b_qweight_reorder.device().is_cuda(), + "b_qweight_reorder is not on GPU"); + TORCH_CHECK(b_qweight_reorder.is_contiguous(), + "b_qweight_reorder is not contiguous"); + + TORCH_CHECK(b_scales_reorder.device().is_cuda(), + "b_scales_reorder is not on GPU"); + TORCH_CHECK(b_scales_reorder.is_contiguous(), + "b_scales_reorder is not contiguous"); + + if (has_zp) { + TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous"); + + TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(), + "b_zeros_reorder is not on GPU"); + TORCH_CHECK(b_zeros_reorder.value().is_contiguous(), + "b_zeros_reorder is not contiguous"); + } + + const uint8_t* matB = reinterpret_cast(b_qweight.data_ptr()); + const void* b_scale = b_scales.data_ptr(); + const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr; + + uint8_t* matB_reorder = + reinterpret_cast(b_qweight_reorder.data_ptr()); + void* b_scale_reorder = b_scales_reorder.data_ptr(); + void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (b_scales.dtype() == at::ScalarType::Half) { + allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>( + matB, reinterpret_cast(b_scale), + reinterpret_cast(b_zero), matB_reorder, + reinterpret_cast<__half*>(b_scale_reorder), + reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream); + } else if (b_scales.dtype() == at::ScalarType::BFloat16) { + allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>( + matB, reinterpret_cast(b_scale), + reinterpret_cast(b_zero), matB_reorder, + reinterpret_cast<__nv_bfloat16*>(b_scale_reorder), + reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align, + stream); + } +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("rearrange_kn_weight_as_n32k16_order", + &rearrange_kn_weight_as_n32k16_order); +} diff --git a/csrc/quantization/gptq_allspark/allspark_utils.cuh b/csrc/quantization/gptq_allspark/allspark_utils.cuh new file mode 100644 index 000000000000..7aded9a17280 --- /dev/null +++ b/csrc/quantization/gptq_allspark/allspark_utils.cuh @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace allspark { + +#define CHECK_CUDA(cmd) \ + do { \ + cudaError_t cuda_status = cmd; \ + if (cuda_status != cudaSuccess) { \ + std::string err_str = cudaGetErrorString(cuda_status); \ + std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \ + << err_str; \ + exit(-1); \ + } \ + } while (0) + +#define CHECK_CUBLAS(cmd) \ + do { \ + cublasStatus_t cublas_status = cmd; \ + if (cublas_status != CUBLAS_STATUS_SUCCESS) { \ + std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \ + << cublas_status << std::endl; \ + exit(-1); \ + } \ + } while (0) + +template +struct SM8x_GEMM_W8A16_Splitk_Params { + const FType* A_ptr; + const QType* B_ptr; + const FType* B_scale_ptr; + const FType* B_zero_ptr; + FType* C_ptr; + int M; + int N; + int K; + int SplitK; + int GroupCnt; + int GroupSize; + FType* C_split_ptr; // for non-fused splitk reduce + float* C_tmp_ptr; // for fused splitk reduce + uint32_t* red_count_ptr; // for fused splitk reduce +}; + +struct alignas(16) BlockTileSplitkParams { + int Mtile; + int Ntile; + int SplitK; + bool EnableFuse; +}; + +template +__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C, + uint32_t n, uint32_t n_matrix, + uint32_t 
matrix_size) { + int idx = blockIdx.x * BLOCK + threadIdx.x; + + if (idx >= matrix_size) { + return; + } + + FType sum(0); + + int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix; + for (int i = 0; i < n_mat; ++i) { + sum += C_split[idx + i * matrix_size]; + } + + C[idx] = sum; +} + +template +void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m, + const uint32_t n, const uint32_t n_matrix, + cudaStream_t stream) { + const int BLOCK = 128; + uint32_t matrix_size = m * n; + int grid = (matrix_size + BLOCK - 1) / BLOCK; + + void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr; + + switch (n_matrix) { + case 4: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 5: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 6: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 7: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 8: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 9: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 10: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 11: + kernel = f16_gemm_splitk_reduce_kernel; + break; + case 12: + kernel = f16_gemm_splitk_reduce_kernel; + break; + default: + kernel = f16_gemm_splitk_reduce_kernel; + break; + } + + kernel<<>>(C_split, C, n, n_matrix, matrix_size); +} + +template +struct HalfType; +template <> +struct HalfType { + using T1 = __half; + using T2 = __half2; +}; +template <> +struct HalfType<__nv_bfloat16> { + using T1 = __nv_bfloat16; + using T2 = __nv_bfloat162; +}; + +// convert 64-bit pointer to 32-bit smem addr +__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) { + uint32_t addr; + asm("{.reg .u64 u64addr;\n" + " cvta.to.shared.u64 u64addr, %1;\n" + " cvt.u32.u64 %0, u64addr;}\n" + : "=r"(addr) + : "l"(smem_ptr)); + + return addr; +} + +template +__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) { + static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @!p mov.b16 %0, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n" +#else + " @p ld.global.ca.b16 {%0}, [%1];}\n" +#endif + : "=h"(reinterpret_cast(r0)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr, + bool guard) { + static_assert(sizeof(T) == 4, "ldg64_ca: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n" +#else + " @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n" +#endif + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3, + const void* ptr, bool guard) { + static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " @!p mov.b32 %0, 0;\n" + " @!p mov.b32 %1, 0;\n" + " @!p mov.b32 %2, 0;\n" + " @!p mov.b32 %3, 0;\n" +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \ + __CUDA_ARCH__ >= 750 + " @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n" +#else + " @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n" +#endif + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)), + "=r"(reinterpret_cast(r2)), + 
"=r"(reinterpret_cast(r3)) + : "l"(ptr), "r"((int)guard)); +} + +template +__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3, + const uint32_t addr) { + static_assert(sizeof(T) == 4, "lds128: invalid T"); + + asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(reg0)), + "=r"(reinterpret_cast(reg1)), + "=r"(reinterpret_cast(reg2)), + "=r"(reinterpret_cast(reg3)) + : "r"(addr)); +} + +template +__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2, + const T& r3, const void* ptr, + bool guard) { + static_assert(sizeof(T) == 4, "stg128: invalid T"); + + asm volatile( + "{.reg .pred p;\n" + " setp.ne.b32 p, %1, 0;\n" + " @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n" + : + : "l"(ptr), "r"((int)guard), "r"(reinterpret_cast(r0)), + "r"(reinterpret_cast(r1)), + "r"(reinterpret_cast(r2)), + "r"(reinterpret_cast(r3))); +} + +template +__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3, + const uint32_t& addr) { + static_assert(sizeof(T) == 4, "ldsm_4: invalid T"); +#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(r0)), + "=r"(reinterpret_cast(r1)), + "=r"(reinterpret_cast(r2)), + "=r"(reinterpret_cast(r3)) + : "r"(addr)); +#endif +} + +template +__device__ __forceinline__ void hmma16816_f32(float (&d)[4], + const uint32_t (&a)[4], + const uint32_t (&b)[2]); + +template <> +__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4], + const uint32_t (&a)[4], + const uint32_t (&b)[2]) { +#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n" + : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); +#endif +} + +template <> +__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>( + float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) { +#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n" + : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); +#endif +} + +template +__device__ __forceinline__ void cp_async(const uint32_t smem_addr, + const void* gmem_ptr, + const int src_in_bytes, bool guard) { + static_assert( + (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16), + "Size is not supported"); +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile( + "{.reg.pred p;\n" + " setp.ne.b32 p, %4, 0;\n" + #if __CUDACC_VER_MINOR__ >= 4 + " @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n" + #else + " @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n" + #endif + ::"r"(smem_addr), + "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard)); +#endif +} + +template +__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr, + const void* gmem_ptr, + const int src_in_bytes, + bool guard) { + static_assert( + (SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16), + "Size is not supported"); +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile( + "{.reg.pred p;\n" + " setp.ne.b32 p, %4, 0;\n" + #if __CUDACC_VER_MINOR__ >= 4 + " @p cp.async.ca.shared.global.L2::256B 
[%0], [%1], %2, %3;}\n" + #else + " @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n" + #endif + ::"r"(smem_addr), + "l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard)); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n"); +#endif +} + +template +__device__ __forceinline__ void cp_asyc_wait_group() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" : : "n"(N)); +#endif +} + +template +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata, + T* fdata); + +template <> +// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128 +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>( + const uint32_t& idata, __half2* fdata) { + uint32_t i10, i32; + asm volatile( + "prmt.b32 %0, %2, 0x64, 0x4140;" + "prmt.b32 %1, %2, 0x64, 0x4342;" + : "=r"(i10), "=r"(i32) + : "r"(idata)); + + static constexpr uint32_t MAGIC_NUM = 0x64806480; + fdata[0] = __hsub2(reinterpret_cast(i10), + reinterpret_cast(MAGIC_NUM)); + fdata[1] = __hsub2(reinterpret_cast(i32), + reinterpret_cast(MAGIC_NUM)); +} + +template <> +// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128 +// reference from marlin fast implementation +__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>( + const uint32_t& idata, __nv_bfloat162* fdata) { + float fp32_imd[4]; + uint32_t* fp32_imd_casted = reinterpret_cast(fp32_imd); + asm volatile( + "prmt.b32 %0, %4, 0x4B000000, 0x7650;" + "prmt.b32 %1, %4, 0x4B000000, 0x7651;" + "prmt.b32 %2, %4, 0x4B000000, 0x7652;" + "prmt.b32 %3, %4, 0x4B000000, 0x7653;" + : "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]), + "=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3]) + : "r"(idata)); + + fp32_imd[0] -= 8388736.f; + fp32_imd[1] -= 8388736.f; + fp32_imd[2] -= 8388736.f; + fp32_imd[3] -= 8388736.f; + + uint32_t* bf16_res = reinterpret_cast(fdata); + asm volatile( + "prmt.b32 %0, %2, %3, 0x7632;" + "prmt.b32 %1, %4, %5, 0x7632;" + : "=r"(bf16_res[0]), "=r"(bf16_res[1]) + : "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]), + "r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3])); +} + +static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(x); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); +} + +} // namespace allspark \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 72de2035d0c1..bf8526879d3f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -447,6 +447,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); + +#ifndef USE_ROCM + // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel + ops.def( + "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " + "Tensor? b_zeros, " + "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, " + "Tensor!? b_zeros_reorder, " + "int K, int N, int N_32align) -> ()"); + // conditionally compiled so impl in source file + + // AllSpark quantization ops + ops.def( + "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " + "Tensor? 
b_qzeros, " + "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " + "CUBLAS_M_THRESHOLD, " + "bool has_zp, bool n32k16_reorder, str weight_name_pattern) -> Tensor"); + // conditionally compiled so impl in source file +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/tests/kernels/test_allspark_gemm.py b/tests/kernels/test_allspark_gemm.py new file mode 100644 index 000000000000..597551ab5ea1 --- /dev/null +++ b/tests/kernels/test_allspark_gemm.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_AMPERE_N_ALIGN) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + + +def is_gptq_allspark_supported(min_capability: int, + max_capability: int) -> bool: + if not current_platform.is_cuda(): + return False + + capability = current_platform.get_device_capability() + assert capability is not None + + return capability.to_int() >= min_capability \ + and capability.to_int() <= max_capability + + +MNK_FACTORS = [ + (1, 4, 8), + (13, 17, 67), + (26, 37, 13), + (48, 16, 24), + (67, 13, 88), + (257, 13, 11), + (658, 13, 11), + (1033, 9, 17), +] + +DTYPES = [torch.float16, torch.bfloat16] +HAS_ZP_OPTS = [False, True] + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +def rand_data(shape, dtype=torch.float16): + return torch.randn(shape, dtype=dtype, device="cuda") + + +@pytest.mark.skipif( + not is_gptq_allspark_supported(80, 89), + reason="AllSpark Ampere kernel is not supported on this GPU type.") +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("group_size", [-1]) +@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): + m_factor, n_factor, k_factor = mnk_factors + m = m_factor + n = n_factor * ALLSPARK_AMPERE_N_ALIGN + k = k_factor * ALLSPARK_AMPERE_K_ALIGN + + input = rand_data((m, k), dtype=dtype) + weight = rand_data((k, n), dtype=dtype) + + # Quantize (and apply act_order if provided) + w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128, + group_size, has_zp) + + qw = qw.to(torch.uint8) + if has_zp: + zp = zp.to(dtype) + properties = torch.cuda.get_device_properties(qw.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + n_32align = (n + 32 - 1) // 32 * 32 + + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp) + opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, + (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, + n_32align)) + + weight_name_pattern = f'model.layers.k{k}.m{m}.n{n}.qweight' + + opcheck(torch.ops._C.allspark_w8a16_gemm, + (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, + sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True, + weight_name_pattern), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) + output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, + n, group_size, sm_count, sm_version, + 
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, True, weight_name_pattern) + + output_ref = torch.matmul(input, w_ref) + torch.cuda.synchronize() + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3306610ad800..0a34c434e9a5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -404,6 +404,22 @@ def machete_prepack_B_fake( memory_format=torch.contiguous_format) +if hasattr(torch.ops._C, "allspark_w8a16_gemm"): + + @register_fake("_C::allspark_w8a16_gemm") + def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, group_size: torch.SymInt, + sm_count: torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, n32k16_reorder: bool, + weight_name_pattern: str) -> torch.Tensor: + m = a.size(0) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + if hasattr(torch.ops._C, "ggml_dequantize"): @register_fake("_C::ggml_dequantize") @@ -881,6 +897,69 @@ def scaled_fp8_quant( return output, scale +# gptq allspark +def allspark_repack_weight( + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + for Ampere W8A16 Fused Gemm kernel + + Args: + qweight: uint8 weight tensor, original k x n format. + scale: fp16/bf16 weight scale tensor, 1 x n format. + zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. + Must be provided for asymmetric quantization. + has_zp: if use symmetric quantization, has_zp = False. + if use asymmetric quantization, has_zp = True. + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + rearranged weight, scale, and optionally zero_point. 
+ """ + K = qweight.shape[0] + N = qweight.shape[1] + N_32align = (N + 32 - 1) // 32 * 32 + + qweight_reorder = torch.empty((N_32align, K), + device=qweight.device, + dtype=qweight.dtype) + scale_reorder = torch.empty((1, N_32align), + device=scale.device, + dtype=scale.dtype) + zero_point_reorder = None + if has_zp: + assert zero_point is not None, ( + "zero_point must be provided for asymmetric quantization.") + zero_point_reorder = torch.empty((1, N_32align), + device=zero_point.device, + dtype=zero_point.dtype) + + torch.ops._C.rearrange_kn_weight_as_n32k16_order( + qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, + zero_point_reorder, K, N, N_32align) + + return qweight_reorder, scale_reorder, zero_point_reorder + + +def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], n: int, + group_size: int, sm_count: int, sm_version: int, + CUBLAS_M_THRESHOLD: int, has_zp: bool, + n32k16_reorder: bool, + weight_name_pattern: str) -> torch.Tensor: + + return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, + n, group_size, sm_count, + sm_version, CUBLAS_M_THRESHOLD, + has_zp, n32k16_reorder, + weight_name_pattern) + + # int8 def scaled_int8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 521724765beb..d5fa496c7493 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -220,7 +220,8 @@ def __init__(self, self.input_size, self.output_size, self.params_dtype, - weight_loader=self.weight_loader) + weight_loader=self.weight_loader, + prefix=prefix) if bias: self.bias = Parameter( @@ -320,7 +321,8 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), + prefix=prefix) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -1068,7 +1070,8 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), + prefix=prefix) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 38df09ff3937..7bbc12d2e778 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -148,11 +148,13 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, weight_loader=weight_loader) layer.register_parameter("weight_g_idx", weight_g_idx) + prefix = kwargs.get("prefix") self.kernel = kernel_type(mp_linear_kernel_config, w_q_param_name="weight_packed", w_s_param_name="weight_scale", w_zp_param_name=None, - w_gidx_param_name="weight_g_idx") + w_gidx_param_name="weight_g_idx", + prefix=prefix) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9f960d9fd37f..e3ca798f8401 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -217,6 +217,7 @@ def create_weights( output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + prefix = extra_weight_attrs.get("prefix") mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), @@ -327,7 +328,8 @@ def create_weights( w_q_param_name="qweight", w_s_param_name="scales", w_zp_param_name="qzeros", - w_gidx_param_name="g_idx") + w_gidx_param_name="g_idx", + prefix=prefix) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index c06befaf3b5a..471d697f906c 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -39,7 +39,8 @@ def __init__(self, w_q_param_name: str, w_s_param_name: str, w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None) -> None: + w_gidx_param_name: Optional[str] = None, + prefix: Optional[str] = None) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -50,6 +51,7 @@ def __init__(self, assert w_gidx_param_name is not None self.w_zp_name = w_zp_param_name self.w_gidx_name = w_gidx_param_name + self.prefix = prefix @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index bcfdb1677716..520e1bc96721 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -3,6 +3,8 @@ from typing import List, Optional, Type import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 + AllSparkLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -16,6 +18,7 @@ # in priority/performance order (when available) _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, + AllSparkLinearKernel, MarlinLinearKernel, ExllamaLinearKernel, ] diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py new file mode 100644 index 000000000000..5827dde2ce8d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.allspark_utils 
import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class AllSparkLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx: + return False, "Act reordering currently not supported by AllSpark" + + if c.zero_points: + return False, "Zero points currently not supported by AllSpark" + + return check_allspark_supported_dtype_shape( + c.partition_weight_shape[0], # in_features + c.partition_weight_shape[1], # out_features + c.group_size, + c.weight_type, + c.act_type) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + # prepare the parameters required for the kernel + properties = torch.cuda.get_device_properties(device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + gemm_args = {} + gemm_args['sm_count'] = sm_count + gemm_args['sm_version'] = sm_version + + self.gemm_args = gemm_args + + # transform param weight, scale + old_weight_param = getattr(layer, self.w_q_name) + old_scale_param = getattr(layer, self.w_s_name) + + assert isinstance(old_weight_param, BasevLLMParameter) + permute_param_layout_(old_weight_param, + input_dim=0, + output_dim=1, + packed_dim=0) + + assert isinstance(old_scale_param, BasevLLMParameter) + permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) + + # unpack weight from K / 4 x N int32 to K x N uint8 + new_weight_param = torch.nn.Parameter(old_weight_param.data, + requires_grad=False) + new_weight_param.data = new_weight_param.data.t().contiguous().view( + dtype=torch.uint8) + new_weight_param.data = new_weight_param.data.t().contiguous() + + new_scale_param = torch.nn.Parameter(old_scale_param.data, + requires_grad=False) + + # reorder K x N weight as N32K16 format for Ampere W8A16 + new_weight_param.data, new_scale_param.data, _ = \ + ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, + c.zero_points) + + replace_parameter(layer, self.w_q_name, new_weight_param.data) + replace_parameter(layer, self.w_s_name, new_scale_param.data) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + gemm_args = self.gemm_args + w_q, w_s, _, _ = self._get_weight_params(layer) + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + assert self.prefix is not None + weight_name = self.prefix if self.prefix is not None else "" + + if re.search(r'\.\d+\.', weight_name): + weight_name_pattern = re.sub(r'\.\d+\.', '.', weight_name, count=1) + else: + weight_name_pattern = weight_name + + output = ops.allspark_w8a16_gemm( + a=reshaped_x, + b_qweight=w_q, + b_scales=w_s, + b_qzeros=None, + n=c.partition_weight_shape[1], + group_size=c.group_size, + sm_count=gemm_args['sm_count'], + sm_version=gemm_args['sm_version'], + CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp=c.zero_points, + n32k16_reorder=True, + 
weight_name_pattern=weight_name_pattern)
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py
new file mode 100644
index 000000000000..97860765a9e1
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from vllm.platforms import current_platform
+from vllm.scalar_type import ScalarType, scalar_types
+
+ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024
+ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128]
+ALLSPARK_AMPERE_N_ALIGN = 16
+ALLSPARK_AMPERE_K_ALIGN = 16
+
+
+def check_allspark_supported_dtype_shape(input_size_per_partition: int,
+                                         output_size_per_partition: int,
+                                         group_size: int,
+                                         weight_dtype: ScalarType,
+                                         act_dtype: torch.dtype):
+    capability_tuple = current_platform.get_device_capability()
+    device_capability = (-1 if capability_tuple is None else
+                         capability_tuple.to_int())
+
+    # For Ampere GPU
+    if device_capability >= 80 and device_capability < 90:
+        if group_size != -1:
+            return False, \
+                "For Ampere GPU, AllSpark does not support group_size "\
+                f"= {group_size}. Only group_size = -1 is supported."
+
+        if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES:
+            return False, "For Ampere GPU, AllSpark does not support "\
+                f"quant type ({weight_dtype}). Only quant types "\
+                f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported."
+
+        if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \
+                or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0:
+            return False, \
+                "AllSpark needs input_size_per_partition % "\
+                f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\
+                f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\
+                "for Ampere GPU optimized kernels."
+
+        if act_dtype != torch.float16 and act_dtype != torch.bfloat16:
+            return False, \
+                "AllSpark only supports act_dtype = float16 or bfloat16 "\
+                f"for Ampere GPU, but got act_dtype = {act_dtype}."
+    else:
+        return False, "AllSpark currently does not support "\
+            f"device_capability = {device_capability}."
+ + return True, None diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index f65dfc3cb329..25a83bda2b8c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -266,7 +266,8 @@ def __init__(self, self.embedding_dim, self.num_embeddings_padded, params_dtype=params_dtype, - weight_loader=self.weight_loader) + weight_loader=self.weight_loader, + prefix=prefix) @classmethod def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, From 1d2905971183c8439459b7917682d6acb40308f5 Mon Sep 17 00:00:00 2001 From: wyj371990 Date: Wed, 26 Feb 2025 19:56:35 +0800 Subject: [PATCH 2/3] [Kernel]: remove prefix in AllSpark kernel Signed-off-by: wyj371990 --- benchmarks/kernels/benchmark_marlin.py | 4 +--- .../gptq_allspark/allspark_qgemm_w8a16.cu | 19 +++---------------- csrc/torch_bindings.cpp | 3 +-- tests/kernels/test_allspark_gemm.py | 7 ++----- vllm/_custom_ops.py | 10 ++++------ vllm/model_executor/layers/linear.py | 9 +++------ .../schemes/compressed_tensors_wNa16.py | 4 +--- .../layers/quantization/gptq_marlin.py | 4 +--- .../kernels/mixed_precision/MPLinearKernel.py | 4 +--- .../kernels/mixed_precision/allspark.py | 12 +----------- .../layers/vocab_parallel_embedding.py | 3 +-- 11 files changed, 19 insertions(+), 60 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 8fd3a8d06f67..21ef491294e3 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -140,8 +140,6 @@ def bench_run(results: List[benchmark.Measurement], model: str, "sm_version": sm_version if as_supported_case else None, "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, - "weight_name_pattern": - f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight', # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, @@ -210,7 +208,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True, weight_name_pattern)", # noqa: E501 + "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index eccc755f90b5..c4ed98ca64f8 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -3,7 +3,6 @@ #include "core/registration.h" #include -std::map as_g_output_map; // cache for 1 layer at::Tensor as_g_workspace; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -12,8 +11,7 @@ torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, torch::Tensor const& b_scales, c10::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, - int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder, - std::string const& weight_name_pattern) { + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { TORCH_CHECK_NOT_IMPLEMENTED( false, "allspark_w8a16_gemm(..) 
requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -919,8 +917,7 @@ torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, torch::Tensor const& b_scales, c10::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, - int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder, - std::string const& weight_name_pattern) { + int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); @@ -960,17 +957,7 @@ torch::Tensor allspark_w8a16_gemm( } auto c_options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - if (as_g_output_map.count(weight_name_pattern) == 0 or - as_g_output_map.at(weight_name_pattern).numel() < m * n) { - as_g_output_map[weight_name_pattern] = torch::empty({m, n}, c_options); - } - torch::Tensor tensor_to_reuse = as_g_output_map[weight_name_pattern]; - std::vector new_shape = {m, n}; - int64_t dim1_step = tensor_to_reuse.stride(1); - int64_t dim0_step = dim1_step * n; - std::vector new_stride = {dim0_step, dim1_step}; - torch::Tensor c = - tensor_to_reuse.as_strided(new_shape, new_stride).to(a.dtype()); + torch::Tensor c = torch::empty({m, n}, c_options); void* c_ptr = reinterpret_cast(c.data_ptr()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bf8526879d3f..0b0334f84efe 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -463,8 +463,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, " "Tensor? b_qzeros, " "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " - "CUBLAS_M_THRESHOLD, " - "bool has_zp, bool n32k16_reorder, str weight_name_pattern) -> Tensor"); + "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); // conditionally compiled so impl in source file #endif } diff --git a/tests/kernels/test_allspark_gemm.py b/tests/kernels/test_allspark_gemm.py index 597551ab5ea1..896e0265738b 100644 --- a/tests/kernels/test_allspark_gemm.py +++ b/tests/kernels/test_allspark_gemm.py @@ -84,17 +84,14 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align)) - weight_name_pattern = f'model.layers.k{k}.m{m}.n{n}.qweight' - opcheck(torch.ops._C.allspark_w8a16_gemm, (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, - sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True, - weight_name_pattern), + sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True), test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - has_zp, True, weight_name_pattern) + has_zp, True) output_ref = torch.matmul(input, w_ref) torch.cuda.synchronize() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0a34c434e9a5..dfbfe207b7fa 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -414,8 +414,8 @@ def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, sm_count: torch.SymInt, sm_version: torch.SymInt, CUBLAS_M_THRESHOLD: torch.SymInt, - has_zp: bool, n32k16_reorder: bool, - weight_name_pattern: str) -> torch.Tensor: + has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: m = a.size(0) return 
torch.empty((m, n), device=a.device, dtype=a.dtype) @@ -950,14 +950,12 @@ def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, b_qzeros: Optional[torch.Tensor], n: int, group_size: int, sm_count: int, sm_version: int, CUBLAS_M_THRESHOLD: int, has_zp: bool, - n32k16_reorder: bool, - weight_name_pattern: str) -> torch.Tensor: + n32k16_reorder: bool) -> torch.Tensor: return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, - has_zp, n32k16_reorder, - weight_name_pattern) + has_zp, n32k16_reorder) # int8 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d5fa496c7493..521724765beb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -220,8 +220,7 @@ def __init__(self, self.input_size, self.output_size, self.params_dtype, - weight_loader=self.weight_loader, - prefix=prefix) + weight_loader=self.weight_loader) if bias: self.bias = Parameter( @@ -321,8 +320,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -1070,8 +1068,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7bbc12d2e778..38df09ff3937 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -148,13 +148,11 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, weight_loader=weight_loader) layer.register_parameter("weight_g_idx", weight_g_idx) - prefix = kwargs.get("prefix") self.kernel = kernel_type(mp_linear_kernel_config, w_q_param_name="weight_packed", w_s_param_name="weight_scale", w_zp_param_name=None, - w_gidx_param_name="weight_g_idx", - prefix=prefix) + w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
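With the name-keyed output cache removed, the kernel allocates its result tensor on every call and the op drops to eleven arguments. A sketch of the updated call, reusing the illustrative variables from the earlier snippet:

    out = ops.allspark_w8a16_gemm(
        a, qw_reorder, s_reorder, None, n, -1,
        props.multi_processor_count, props.major * 10 + props.minor,
        ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
        False,  # has_zp
        True)   # n32k16_reorder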
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e3ca798f8401..9f960d9fd37f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -217,7 +217,6 @@ def create_weights( output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") - prefix = extra_weight_attrs.get("prefix") mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), @@ -328,8 +327,7 @@ def create_weights( w_q_param_name="qweight", w_s_param_name="scales", w_zp_param_name="qzeros", - w_gidx_param_name="g_idx", - prefix=prefix) + w_gidx_param_name="g_idx") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 471d697f906c..c06befaf3b5a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -39,8 +39,7 @@ def __init__(self, w_q_param_name: str, w_s_param_name: str, w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None, - prefix: Optional[str] = None) -> None: + w_gidx_param_name: Optional[str] = None) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -51,7 +50,6 @@ def __init__(self, assert w_gidx_param_name is not None self.w_zp_name = w_zp_param_name self.w_gidx_name = w_gidx_param_name - self.prefix = prefix @abstractmethod def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py index 5827dde2ce8d..56fdd6a18e0d 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import re from typing import Optional, Tuple import torch @@ -97,14 +96,6 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) - assert self.prefix is not None - weight_name = self.prefix if self.prefix is not None else "" - - if re.search(r'\.\d+\.', weight_name): - weight_name_pattern = re.sub(r'\.\d+\.', '.', weight_name, count=1) - else: - weight_name_pattern = weight_name - output = ops.allspark_w8a16_gemm( a=reshaped_x, b_qweight=w_q, @@ -116,8 +107,7 @@ def apply_weights(self, sm_version=gemm_args['sm_version'], CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp=c.zero_points, - n32k16_reorder=True, - weight_name_pattern=weight_name_pattern) + n32k16_reorder=True) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 25a83bda2b8c..f65dfc3cb329 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -266,8 +266,7 @@ def __init__(self, self.embedding_dim, self.num_embeddings_padded, 
params_dtype=params_dtype, - weight_loader=self.weight_loader, - prefix=prefix) + weight_loader=self.weight_loader) @classmethod def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, From 9d319eba55d990aad2a27f4d68cb438c627ec095 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 28 Feb 2025 02:26:35 +0000 Subject: [PATCH 3/3] Remove asserts Signed-off-by: mgoin --- tests/quantization/test_compressed_tensors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c187b4c7ed99..b9b2b634e0bb 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -215,8 +215,6 @@ def check_model(model): assert qkv_proj.scheme.group_size == (-1 if group is None else group) - assert qkv_proj.weight_packed.dtype is torch.int32 - assert qkv_proj.weight_scale.dtype is torch.float16 assert qkv_proj.scheme.pack_factor == pack_factor llm.apply_model(check_model)
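End to end, users only see a change in kernel selection: on an Ampere (SM 8.0-8.9) GPU, a GPTQ checkpoint with 8-bit weights, per-channel scales (group_size = -1), no zero points, and no activation reordering should be routed to AllSparkLinearKernel, which sits between Machete and Marlin in the mixed-precision kernel priority list. A sketch, with a hypothetical checkpoint name:

    from vllm import LLM, SamplingParams

    # Hypothetical repo id; any per-channel, symmetric 8-bit GPTQ checkpoint applies.
    llm = LLM(model="some-org/Llama-2-7b-GPTQ-w8-perchannel")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)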