From dacdf8cf649564e7efce991f643cce449a460639 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 09:56:00 +0800 Subject: [PATCH 01/35] fp4 marlin kernel Signed-off-by: Jinzhen Lin --- csrc/core/scalar_type.hpp | 3 + csrc/quantization/gptq_marlin/dequant.h | 320 ++++++++++++++---- .../gptq_marlin/generate_kernels.py | 10 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 26 +- .../gptq_marlin/marlin_template.h | 57 ++-- .../kernels/quantization/test_marlin_gemm.py | 101 ++---- .../layers/quantization/modelopt.py | 31 +- .../layers/quantization/utils/marlin_utils.py | 11 +- .../quantization/utils/marlin_utils_fp4.py | 288 ++++++++++++++++ .../quantization/utils/marlin_utils_fp8.py | 25 +- 10 files changed, 691 insertions(+), 181 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index c2ae554c9f8e..97c3e6cf53d0 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1fn = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 3c0d77ac345d..29b76a4156cc 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -1,3 +1,67 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accurary issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. 
But this may have accurary issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multipiler)) + = mul(bit_op(weight), scale_factor * multipiler) + +where `scale_factor * multipiler` can be computed at weight loading. + +*/ #include "marlin_dtypes.cuh" @@ -27,7 +91,7 @@ __device__ inline uint32_t prmt(uint32_t a) { return res; } -template +template __device__ inline void dequant(int q, scalar_t2* frag_b); // @@ -40,7 +104,20 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -62,7 +139,12 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -84,7 +166,7 @@ __device__ inline void dequant(int q, half2* frag_b) { } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -96,39 +178,40 @@ __device__ inline void dequant( int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); // clang-format on - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&SUB)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; + dequant(q, frag_b); +} - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- // clang-format off - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - // clang-format on +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; + static constexpr uint32_t SUB = 0x43004300; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&SUB)); } // @@ -140,7 +223,7 @@ __device__ inline void dequant( // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // template <> -__device__ inline void dequant(int q, +__device__ inline void dequant(int q, half2* frag_b) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; @@ -149,33 +232,40 @@ __device__ inline void dequant(int q, uint32_t lo = prmt(q); uint32_t hi = prmt(q); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} - frag_b[0] = __hsub2(*reinterpret_cast(&lo), +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - frag_b[0] = __hsub2(*reinterpret_cast(&lo), + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -200,7 +290,7 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = @@ -225,22 +315,30 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant(int q, +__device__ inline void dequant(int q, half2* frag_b) { // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); 
- constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; // Extract and shift FP8 values to FP16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -248,28 +346,36 @@ __device__ inline void dequant(int q, const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); // Convert to half2 and apply bias - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 + constexpr int MASK = 0x7F007F00; - // Extract and shift FP8 values to BF16 format + // Extract and shift FP8 values to BF16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; // Construct and apply exponent bias constexpr int BIAS_OFFSET = @@ -281,9 +387,85 @@ __device__ inline void dequant( __float2bfloat162_rn(*reinterpret_cast(&BIAS)); // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + frag_b[1] = 
*reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } #endif diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 8b4b951f3d86..fb4dcd5280d4 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] @@ -40,7 +43,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -73,6 +76,9 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" ^ group_blocks != 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 02527a481661..3f991180f5d6 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -258,6 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for fp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -314,6 +315,23 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + // We currently have 4-bit models only with group_blocks == 4 #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ @@ -366,6 +384,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) + FP4_GET_IF(vllm::kFE2M1fn) + BIGGROUP_GET_IF(vllm::kFE4M3fn) ACT_GET_IF(vllm::kU4B8) @@ -447,7 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, + q_type == vllm::kFE4M3fn || + q_type == vllm::kFE2M1fn, "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. Got = ", q_type.str()); @@ -774,7 +795,8 @@ torch::Tensor gptq_marlin_gemm( "b_q_type must be u4 or u8 when has_zp = True. 
Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, + b_q_type == vllm::kFE4M3fn || + b_q_type == vllm::kFE2M1fn, "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. Got = ", b_q_type.str()); diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index ca05b8a25f86..9e8a6caa34cf 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -325,6 +325,13 @@ __global__ void Marlin( static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || \ + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || \ + has_zp && !is_zp_float && !std::is_same::value || \ + has_zp && !is_zp_float && !(w_type == vllm::kU8); + constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); @@ -567,7 +574,8 @@ __global__ void Marlin( if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -681,7 +689,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -1065,22 +1073,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. 
@@ -1110,7 +1103,7 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; @@ -1125,7 +1118,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1138,6 +1134,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1145,7 +1146,7 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1153,7 +1154,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr(!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1408,7 +1409,7 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } @@ -1488,7 +1489,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr(!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters); @@ -1563,7 +1566,7 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1573,7 +1576,7 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1597,7 +1600,7 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && 
- w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index c125e0b5ec75..b79ee837fd4c 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -20,6 +20,8 @@ MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_scales, query_marlin_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -190,9 +192,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) +@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES + [16]) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -210,6 +211,7 @@ def test_gptq_marlin_gemm( use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] size_m = m_factor size_k = k_chunk * k_factor @@ -220,6 +222,8 @@ def test_gptq_marlin_gemm( return if group_size == size_k: return + if has_zp: + return if size_k % group_size != 0: return @@ -227,7 +231,15 @@ def test_gptq_marlin_gemm( a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - if quant_type == scalar_types.float8_e4m3fn: + if quant_type == scalar_types.float4_e2m1fn: + if group_size != 16 or act_order: + return + w_ref, marlin_q_w, marlin_s = rand_marlin_weight_fp4_like( + b_weight.T, group_size) + g_idx = None + sort_indices = None + marlin_zp = None + elif quant_type == scalar_types.float8_e4m3fn: if group_size not in [-1, 128]: return if act_order: @@ -236,21 +248,23 @@ def test_gptq_marlin_gemm( b_weight.T, group_size) g_idx = None sort_indices = None + marlin_zp = None + elif has_zp: + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, quant_type, group_size) + g_idx = None + sort_indices = None else: + if group_size == 16: + return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, quant_type, group_size, act_order) - - marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + marlin_zp = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck( - torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) - output = ops.gptq_marlin_gemm( a_input, None, @@ -339,67 +353,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 
0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_awq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - use_fp32_reduce, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) - - g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - is_k_full = True - - workspace = marlin_make_workspace_new(a_input.device) - - output = ops.gptq_marlin_gemm( - a_input, - None, - marlin_q_w, - marlin_s, - marlin_zp, - g_idx, - sort_indices, - workspace, - quant_type, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - is_k_full=is_k_full, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - output_ref = torch.matmul(a_input, w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 828447dd1019..9a97879213a6 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -15,6 +15,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -194,7 +197,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -264,9 +267,15 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + if not self.cutlass_nvfp4_supported: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and above.") + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. 
Please use Blackwell and" + " above.") def create_weights( self, @@ -378,12 +387,28 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, requires_grad=False) + if self.use_marlin: + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled + def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self.use_marlin: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + output_dtype = x.dtype # for input only the contracting dimension has a constraint. diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index a2b1b7cb0e1d..cdc383fddbea 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -33,7 +33,7 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, + has_zp: bool = None, include_fp_type: bool = True, device_capability: Optional[int] = None, ): @@ -45,6 +45,13 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point return [scalar_types.uint4] @@ -52,7 +59,7 @@ def query_marlin_supported_quant_types( # GPTQ style, unsigned + symmetric bias res = [scalar_types.uint4b8, scalar_types.uint8b128] if include_fp_type: - res += [scalar_types.float8_e4m3fn] + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1fn] return res diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 000000000000..f566a3ed582c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +def is_fp4_marlin_supported(): + return current_platform.has_device_capability(80) + + +def fp4_fused_exponent_bias_into_scales(scales): + fp4_exponent = 2 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + 
workspace: torch.Tensor, + size_n: int, + size_k: int, + extra_scale_factor: int = 1, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + if extra_scale_factor != 1 and size_n > size_k: + reshaped_x = reshaped_x * extra_scale_factor + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if extra_scale_factor != 1 and size_n <= size_k: + output = output * extra_scale_factor + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + scales = layer.weight_scale.to(torch.float32) * \ + layer.weight_scale_2.to(torch.float32) + + layer.marlin_extra_scale_factor = 1 + if scales.max() >= 2: + # We would scaling the `scales` tensors later, it would overflow + # if the value is greater than or equal to 2 ** fp4_exponent = 4. + # So we first divide the scales by a certain value to avoid overflow. + # Afterwards, we will multiply the computation results of + # the Marlin kernel by this value. + s = 2**(scales.max() / 2).log2().ceil().item() + layer.marlin_extra_scale_factor = s + scales = scales / s + + marlin_scales = marlin_permute_scales(s=scales.T.to(param_dtype), + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + marlin_scales = fp4_fused_exponent_bias_into_scales(marlin_scales) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=8) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + if layer.weight_block_size is None: + group_size = -1 + else: + group_size = layer.weight_block_size[1] + + for name in ["w13", "w2"]: + if name + "_weight_scale" in dir(layer): + new_name = name + "_weight_scale" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + elif name + "_weight_scale_inv" in dir(layer): + new_name = name + "_weight_scale_inv" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + # marlin kernel only support channel-wise and group-wise quantization + # we need to convert the scales + if layer.weight_block_size is None: + if scales.nelement() == e: + # tensor-wise quantization -> channel-wise quantization + # (e, 1, 1) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2) + elif scales.nelement() > e and scales.nelement() != e * size_n: + assert (e * size_n) % scales.nelement() == 0 + s_size = scales.nelement() // e + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (e, 1, s_size) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, s_size) + scales = scales.repeat_interleave(size_n // s_size, 2) + else: + # channel-wise quantization + # (e, 1, size_n) + scales = scales.view(e, 1, size_n) + else: + # block-wise quantization -> group-wise quantization + # (e, size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (e, size_k // block_size[1], size_n) + block_n = layer.weight_block_size[0] + scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) + # size_n may not divisible by block_size[0] + scales = scales[..., :size_n].contiguous() + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i], + size_k=size_k, + size_n=size_n, + group_size=group_size) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = fp4_fused_exponent_bias_into_scales(scales) + scales = torch.nn.Parameter(scales, requires_grad=False) + + setattr(layer, name + "_weight_scale", scales) + + +def rand_marlin_weight_fp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + 
((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view( + size_n, size_k) * scales.repeat_interleave(group_size, 1) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = fp4_fused_exponent_bias_into_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 1e0078e246be..d516db537687 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -19,6 +19,20 @@ def is_fp8_marlin_supported(): return current_platform.has_device_capability(80) +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( input: torch.Tensor, weight: torch.Tensor, @@ -132,8 +146,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.T.contiguous() block_n = layer.weight_block_size[0] - scales = scales.T.repeat_interleave(block_n, 1) + scales = scales.repeat_interleave(block_n, 1) # size_n may not divisible by block_size[0] scales = scales[:, :part_size_n] @@ -141,6 +157,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, size_k=part_size_k, size_n=part_size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) @@ -239,8 +256,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, # block-wise quantization -> group-wise quantization # (e, size_k // block_size[1], ceil(size_n / block_size[0])) # =>(repeat)=> (e, size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.permute(0, 2, 1) block_n = layer.weight_block_size[0] - scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) + scales = scales.repeat_interleave(block_n, 2) # size_n may not divisible by block_size[0] scales = scales[..., :size_n].contiguous() @@ -252,6 +271,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = fp8_fused_exponent_bias_into_scales(scales) scales = 
torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) @@ -301,5 +321,6 @@ def marlin_quant_fp8_torch(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) return weight_ref.T, marlin_qweight, marlin_scales From e2c0ad38b250694df66ab9cdf94047936face57e Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 10:19:11 +0800 Subject: [PATCH 02/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/generate_kernels.py | 2 +- csrc/quantization/gptq_marlin/marlin_template.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index fb4dcd5280d4..40b73f02a9c5 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -33,7 +33,7 @@ # we don't add it to reduce wheel size. SCALAR_TYPES = [ "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kFE2M1fn" ] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index 9e8a6caa34cf..b16ad0da7d32 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -1118,7 +1118,7 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1f.id()) { + if constexpr (w_type_id == vllm::kFE2M1fn.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; } else if constexpr (w_type.size_bits() == 4) { From 0d5368bfed6207887415879aa868b6b7c4af8356 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 10:45:45 +0800 Subject: [PATCH 03/35] fix Signed-off-by: Jinzhen Lin --- .../gptq_marlin/generate_kernels.py | 2 +- .../gptq_marlin/marlin_template.h | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 40b73f02a9c5..759eb972fded 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -77,7 +77,7 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if scalar_type == "vllm::kFE2M1f" ^ group_blocks != 1: + if (scalar_type == "vllm::kFE2M1f") ^ (group_blocks != 1): continue k_blocks = thread_configs[0] // 16 diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index b16ad0da7d32..b5abb94603c3 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -325,12 +325,13 @@ __global__ void Marlin( static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; - constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || \ - w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details - constexpr bool dequant_skip_flop = !is_int_type || \ - has_zp && !is_zp_float && !std::is_same::value || \ 
- has_zp && !is_zp_float && !(w_type == vllm::kU8); + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); @@ -1146,7 +1147,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1154,7 +1156,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr(!dequant_skip_flop && has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1409,7 +1411,8 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } @@ -1489,7 +1492,7 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - if constexpr(!dequant_skip_flop) { + if constexpr (!dequant_skip_flop) { fetch_col_scale_to_shared(); } } @@ -1566,7 +1569,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1576,7 +1580,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1600,7 +1605,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { From 8d51e3248f3f04f3e57b727a9854ac81762f5b6a Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 10:51:29 +0800 Subject: [PATCH 04/35] fix Signed-off-by: Jinzhen Lin --- csrc/core/scalar_type.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 
97c3e6cf53d0..91193a6abc97 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -334,7 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; -static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat4_e2m1f = kFE2M1fn; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; From c879e997245b4fdef4ccea38d685c915cd45e55f Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 12:24:06 +0800 Subject: [PATCH 05/35] fix format Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/dequant.h | 60 ++++++++++--------- .../layers/quantization/utils/marlin_utils.py | 2 +- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 29b76a4156cc..b626871492cf 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -57,7 +57,7 @@ We cannot fused byte_perm with scaling. scale(flop(bit_op(weight))) = scale(mul(bit_op(weight), multipiler)) - = mul(bit_op(weight), scale_factor * multipiler) + = mul(bit_op(weight), scale_factor * multipiler) where `scale_factor * multipiler` can be computed at weight loading. @@ -91,7 +91,8 @@ __device__ inline uint32_t prmt(uint32_t a) { return res; } -template +template __device__ inline void dequant(int q, scalar_t2* frag_b); // @@ -104,7 +105,8 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { const int MASK = 0x000f000f; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. 
@@ -117,7 +119,8 @@ __device__ inline void dequant(int q, half2* frag } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -139,12 +142,14 @@ __device__ inline void dequant(int q, half2* fra } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -189,10 +194,8 @@ __device__ inline void dequant( static constexpr uint32_t SUB = 0x43084308; - frag_b[0] = __hsub2(frag_b[0], - *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], - *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } template <> @@ -208,10 +211,8 @@ __device__ inline void dequant( static constexpr uint32_t SUB = 0x43004300; - frag_b[0] = __hsub2(frag_b[0], - *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], - *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } // @@ -224,7 +225,7 @@ __device__ inline void dequant( // template <> __device__ inline void dequant(int q, - half2* frag_b) { + half2* frag_b) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -237,8 +238,8 @@ __device__ inline void dequant(int q, } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { dequant(q, frag_b); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; @@ -249,12 +250,14 @@ __device__ inline void dequant(int q, } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { dequant(q, frag_b); static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; @@ -315,8 +318,8 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { // Constants for FP8 (E4M3) and FP16 formats constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; @@ -333,8 +336,8 @@ __device__ inline void dequant(int q, } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { dequant(q, frag_b); // Constants for FP8 (E4M3) and FP16 formats @@ -359,7 +362,7 @@ __device__ inline void dequant( constexpr int MASK = 0x7F007F00; - // Extract and shift FP8 values to BF16 format + // Extract and shift FP8 values to BF16 format int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); q <<= 8; int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); @@ -391,10 +394,9 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } - template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { // Constants for FP4 (E2M1) 
and FP16 formats constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; @@ -411,8 +413,8 @@ __device__ inline void dequant(int q, } template <> -__device__ inline void dequant(int q, - half2* frag_b) { +__device__ inline void dequant( + int q, half2* frag_b) { dequant(q, frag_b); // Constants for FP4 (E2M1) and FP16 formats diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index cdc383fddbea..c9bebf35555b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -33,7 +33,7 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool = None, + has_zp: Optional[bool] = None, include_fp_type: bool = True, device_capability: Optional[int] = None, ): From bb547a60ac7a1161e5be8249c9d89f3e76d3e504 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 12:43:39 +0800 Subject: [PATCH 06/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/generate_kernels.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 759eb972fded..6432f334f708 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -77,7 +77,10 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if (scalar_type == "vllm::kFE2M1f") ^ (group_blocks != 1): + if scalar_type == "vllm::kFE2M1fn" and group_blocks != 1: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1fn" and group_blocks == 1: continue k_blocks = thread_configs[0] // 16 From 4dddda5d44417bf306e365bb02c11447f3916e14 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 12:51:59 +0800 Subject: [PATCH 07/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/dequant.h | 10 +++++----- csrc/quantization/gptq_marlin/gptq_marlin.cu | 14 ++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index b626871492cf..96afe4586fff 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -34,7 +34,7 @@ If has float zero points: = bit_op(weight) - (fzp + bias) where the `fzp + bias` can be computed at weight loading. But this -may have accurary issue, so we should not use this in most cases. +may have accuracy issue, so we should not use this in most cases. If has not zero points: @@ -43,7 +43,7 @@ If has not zero points: = scale(bit_op(weight)) - scale(bias) = fma(bit_op(weight), scale_factor, scale(bias)) -where the `scale(bias)` can be cached. But this may have accurary issue, +where the `scale(bias)` can be cached. But this may have accuracy issue, so we should not use this in most cases. @@ -56,10 +56,10 @@ We cannot fused byte_perm with scaling. 
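(The reason is that byte_perm / prmt only rearranges bytes and performs no
arithmetic, so there is no floating-point operation to fold the scale into;
the scale has to be applied as a separate multiply after the conversion.)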
## FP4/FP8 => FP16/BF16 scale(flop(bit_op(weight))) - = scale(mul(bit_op(weight), multipiler)) - = mul(bit_op(weight), scale_factor * multipiler) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) -where `scale_factor * multipiler` can be computed at weight loading. +where `scale_factor * multiplier` can be computed at weight loading. */ diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 3f991180f5d6..48d9665adb5c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -315,13 +315,13 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) #define FP4_GET_IF(W_TYPE) \ @@ -467,8 +467,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || - q_type == vllm::kFE2M1fn, + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1fn, "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. Got = ", q_type.str()); @@ -795,8 +794,7 @@ torch::Tensor gptq_marlin_gemm( "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || - b_q_type == vllm::kFE2M1fn, + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1fn, "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. 
Got = ", b_q_type.str()); From 9aac76acdb30595feb0738dd7eb4dea1d826ca2a Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 16:38:28 +0800 Subject: [PATCH 08/35] kFE2M1fn -> kFE2M1f Signed-off-by: Jinzhen Lin --- csrc/core/scalar_type.hpp | 4 ++-- csrc/quantization/gptq_marlin/dequant.h | 12 ++++++------ csrc/quantization/gptq_marlin/generate_kernels.py | 6 +++--- csrc/quantization/gptq_marlin/gptq_marlin.cu | 6 +++--- csrc/quantization/gptq_marlin/marlin_template.h | 2 +- tests/kernels/quantization/test_marlin_gemm.py | 2 +- .../layers/quantization/utils/marlin_utils.py | 2 +- .../layers/quantization/utils/marlin_utils_fp4.py | 2 +- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 91193a6abc97..d0f85e23609b 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,7 +315,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); -static inline constexpr auto kFE2M1fn = +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); @@ -334,7 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; -static inline constexpr auto kFloat4_e2m1f = kFE2M1fn; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 96afe4586fff..a318f7583c63 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -395,7 +395,7 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, half2* frag_b) { // Constants for FP4 (E2M1) and FP16 formats constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; @@ -413,9 +413,9 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, half2* frag_b) { - dequant(q, frag_b); + dequant(q, frag_b); // Constants for FP4 (E2M1) and FP16 formats constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; @@ -431,7 +431,7 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { // Constants for FP4 (E2M1) and FP16 formats constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; @@ -449,9 +449,9 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( +__device__ inline void dequant( int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); + dequant(q, frag_b); // Constants for FP4 (E2M1) and BF16 formats constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 6432f334f708..4ac7121ab4e1 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -33,7 +33,7 @@ # we don't add it to reduce wheel size. 
SCALAR_TYPES = [ "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1fn" + "vllm::kFE2M1f" ] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] @@ -77,10 +77,10 @@ def generate_new_kernels(): if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue # nvfp4 only supports group_size == 16 - if scalar_type == "vllm::kFE2M1fn" and group_blocks != 1: + if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: continue # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1fn" and group_blocks == 1: + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: continue k_blocks = thread_configs[0] // 16 diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 48d9665adb5c..99b913166f8d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -384,7 +384,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, COMMON_GET_IF(vllm::kU4B8) COMMON_GET_IF(vllm::kU8B128) - FP4_GET_IF(vllm::kFE2M1fn) + FP4_GET_IF(vllm::kFE2M1f) BIGGROUP_GET_IF(vllm::kFE4M3fn) @@ -467,7 +467,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1fn, + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. Got = ", q_type.str()); @@ -794,7 +794,7 @@ torch::Tensor gptq_marlin_gemm( "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1fn, + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " "has_zp = False. 
Got = ", b_q_type.str()); diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index b5abb94603c3..2fd6d79bb6bf 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -1119,7 +1119,7 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type_id == vllm::kFE2M1fn.id()) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { b_quant_1 = frag_b_quant[k2][0][j]; b_quant_0 = b_quant_1 << 8; } else if constexpr (w_type.size_bits() == 4) { diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index b79ee837fd4c..734e5ecac760 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -231,7 +231,7 @@ def test_gptq_marlin_gemm( a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - if quant_type == scalar_types.float4_e2m1fn: + if quant_type == scalar_types.float4_e2m1f: if group_size != 16 or act_order: return w_ref, marlin_q_w, marlin_s = rand_marlin_weight_fp4_like( diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index c9bebf35555b..b1b352c4bdac 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -59,7 +59,7 @@ def query_marlin_supported_quant_types( # GPTQ style, unsigned + symmetric bias res = [scalar_types.uint4b8, scalar_types.uint8b128] if include_fp_type: - res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1fn] + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] return res diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index f566a3ed582c..2ebd8f6d70cd 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -66,7 +66,7 @@ def apply_fp4_marlin_linear( g_idx=None, perm=None, workspace=workspace, - b_q_type=scalar_types.float4_e2m1fn, + b_q_type=scalar_types.float4_e2m1f, size_m=reshaped_x.size(0), size_n=size_n, size_k=size_k, From 8392d738135d236b9eb3d6d4dcb9f58032de6947 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 16:52:34 +0800 Subject: [PATCH 09/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/dequant.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index a318f7583c63..864900059467 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -395,8 +395,8 @@ __device__ inline void dequant( } template <> -__device__ inline void dequant( - int q, half2* frag_b) { +__device__ inline void dequant(int q, + half2* frag_b) { // Constants for FP4 (E2M1) and FP16 formats constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; From 5050d4b549b9fa716d966683af6df7ccff0c0e4a Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 20:48:39 +0800 Subject: [PATCH 10/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 6 +++--- tests/kernels/quantization/test_marlin_gemm.py | 13 +++++++++++-- .../layers/quantization/utils/marlin_utils.py | 1 + 
.../layers/quantization/utils/marlin_utils_fp4.py | 3 +++ 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 99b913166f8d..f04eba7dbdd3 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -258,7 +258,7 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) - // FP4: cases for fp4(e2m1) (group_blocks == 1) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -468,7 +468,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, } else { TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " "has_zp = False. Got = ", q_type.str()); } @@ -795,7 +795,7 @@ torch::Tensor gptq_marlin_gemm( } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128 or float8_e4m3fn when " + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " "has_zp = False. Got = ", b_q_type.str()); } diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 734e5ecac760..65c69bb2f692 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -21,7 +21,7 @@ marlin_make_workspace_new, marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like) + rand_marlin_weight_fp4_like, FP4_MARLIN_SUPPORTED_GROUP_SIZES) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -193,7 +193,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES + [16]) +@pytest.mark.parametrize("group_size", + set(MARLIN_SUPPORTED_GROUP_SIZES + + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -265,6 +267,13 @@ def test_gptq_marlin_gemm( workspace = marlin_make_workspace_new(w_ref.device) + opcheck( + torch.ops._C.gptq_marlin_gemm, + (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, + workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], + a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) + output = ops.gptq_marlin_gemm( a_input, None, diff --git 
a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b1b352c4bdac..0f5f91b03861 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -45,6 +45,7 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # (has_zp is None) means return both quant_types that has zp and has not zp if has_zp is None: types0 = query_marlin_supported_quant_types(False, include_fp_type, device_capability) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 2ebd8f6d70cd..532db90fd036 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -12,6 +12,9 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + logger = init_logger(__name__) From 02576a9a47eefaecf618900283cf91569d857173 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 20:50:50 +0800 Subject: [PATCH 11/35] fix comment Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/utils/marlin_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0f5f91b03861..2e7425e6166c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -45,7 +45,9 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] - # (has_zp is None) means return both quant_types that has zp and has not zp + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both if has_zp is None: types0 = query_marlin_supported_quant_types(False, include_fp_type, device_capability) From 49978ad36f572df793e2db6e986a18a78e507c81 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 22:00:17 +0800 Subject: [PATCH 12/35] fix Signed-off-by: Jinzhen Lin --- vllm/model_executor/layers/quantization/modelopt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 9a97879213a6..5593e305a006 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -407,6 +407,7 @@ def apply( workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, + extra_scale_factor=layer.marlin_extra_scale_factor, bias=bias) output_dtype = x.dtype From af12b22b5664d9a2a72dfbbd51897c6b06637bdd Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 22:37:11 +0800 Subject: [PATCH 13/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 14 ++++++++------ tests/kernels/quantization/test_marlin_gemm.py | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index f04eba7dbdd3..3eaa1189facb 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -466,11 +466,12 
@@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, q_type == vllm::kU4 || q_type == vllm::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -795,7 +796,8 @@ torch::Tensor gptq_marlin_gemm( } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " "has_zp = False. Got = ", b_q_type.str()); } diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 65c69bb2f692..51efa9806506 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -21,7 +21,7 @@ marlin_make_workspace_new, marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_fp4_like, FP4_MARLIN_SUPPORTED_GROUP_SIZES) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -193,9 +193,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) -@pytest.mark.parametrize("group_size", - set(MARLIN_SUPPORTED_GROUP_SIZES + - FP4_MARLIN_SUPPORTED_GROUP_SIZES)) +@pytest.mark.parametrize( + "group_size", + set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) From fa0d098cd3c8bd826f59eb808a5884d0fe9dea93 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 6 May 2025 22:48:22 +0800 Subject: [PATCH 14/35] fix Signed-off-by: Jinzhen Lin --- .../model_executor/layers/quantization/utils/marlin_utils_fp4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 532db90fd036..bfc0a5e2200d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -14,7 +14,6 @@ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] - logger = init_logger(__name__) From e6265a61f009e137a5591707d68aad62a34f3404 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 11:37:06 +0800 Subject: [PATCH 15/35] update Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/dequant.h | 36 +++++++++ csrc/quantization/gptq_marlin/gptq_marlin.cu 
| 48 +++++++++--- csrc/quantization/gptq_marlin/kernel.h | 4 +- .../marlin_fp8_scales_preproocess.cu | 75 +++++++++++++++++++ .../gptq_marlin/marlin_template.h | 45 ++++++++--- csrc/torch_bindings.cpp | 7 +- vllm/_custom_ops.py | 11 ++- .../layers/quantization/modelopt.py | 2 +- .../quantization/utils/marlin_utils_fp4.py | 52 ++++++------- 9 files changed, 230 insertions(+), 50 deletions(-) create mode 100644 csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index 864900059467..b49720a446ad 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -470,6 +470,42 @@ __device__ inline void dequant( frag_b[0] = __hmul2(frag_b[0], bias_reg); } +template +__device__ inline void dequant_fp8_scales( + int q, scalar_t2* frag_b); + + +template <> +__device__ inline void dequant_fp8_scales( + int q, half2* frag_b) { + + int Out1 = (q & 0xFF00FF00) >> 1;; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + + +template <> +__device__ inline void dequant_fp8_scales( + int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + #endif } // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 3eaa1189facb..78dc90ff6146 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -315,14 +315,14 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ #define FP4_GET_IF(W_TYPE) \ FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ @@ -453,7 +453,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void* s2, void* zp, 
void* g_idx, void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k, int lda, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, @@ -504,6 +504,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -622,7 +623,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, use_fp32_reduce, max_shared_mem_new); // clang-format on @@ -638,6 +639,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor gptq_marlin_gemm( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -780,6 +782,17 @@ torch::Tensor gptq_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(b_q_type != vllm::kFE2M1f, + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -842,19 +855,34 @@ torch::Tensor gptq_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + scales_ptr, global_scale.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index eb2700c95e86..82f7df07ea79 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -10,7 +10,9 @@ #define MARLIN_KERNEL_PARAMS \ const int4 *__restrict__ A, const int4 
*__restrict__ B, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t* __restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, \ const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \ int prob_k, int lda, int *locks, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem diff --git a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu new file mode 100644 index 000000000000..7c4e933cba23 --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu @@ -0,0 +1,75 @@ + +#include +#include + + +__global__ void marlin_fp8_scales_preprocess_kernel(int4* __restrict__ in_ptr, +int4* __restrict__ out_ptr, int64_t s_size) { + + // convert subnormal fp8_e4m3 value to fp8_e5m3_val + // #0bEEEEEMMM // subnormal_e4m3_val = e5m3_val + const uint8_t subnormal_val_map[9] = { + 0b00000000, // 0 / 2 = 0 + 0b00110000, // 1 / 8 * (2 ** -6) = 1.00 * (2 ** (6 - 15)) + 0b00111000, // 2 / 8 * (2 ** -6) = 1.00 * (2 ** (7 - 15)) + 0b00111100, // 3 / 8 * (2 ** -6) = 1.50 * (2 ** (7 - 15)) + 0b01000000, // 4 / 8 * (2 ** -6) = 1.00 * (2 ** (8 - 15)) + 0b01000010, // 5 / 8 * (2 ** -6) = 1.25 * (2 ** (8 - 15)) + 0b01000100, // 6 / 8 * (2 ** -6) = 1.50 * (2 ** (8 - 15)) + 0b01000110, // 7 / 8 * (2 ** -6) = 1.75 * (2 ** (8 - 15)) + }; + + int offset = blockIdx.x * blockDim.x; + + // Note that after the conversion, + // the first bit of all values (except 0.0) is 1 + auto process_val = [&](uint8_t val) { + if (val == 0) return 0; + + // normalized value case + // (x | 0x80): set the top bit of exponent to 1 + // so that we have less exponent bias with fp16/bf16 + // (x - 8): divide the fp8 value by 2 + // to avoid the value become NaN after dequantization + // when x = *reinterpret_cast(&fp8_val) + // (x - 8 * y) means the exponent is decreased by y, + // which corresponds to dividing the fp8 value by 2 ** y + else if (val >= 8) return (val | 0x80) - 8; + + // subnormal value (all exponent bits is 0) + // (x - 8 * 8): to match the exponent bias used by normalized numbers + // (x - 8): same with normalized value case + else return (subnormal_val_map[val] | 0x80) - 8 * (8 + 1); + }; + + for (int i = offset + threadIdx.x; i < s_size / 16; i += blockDim.x) { + int4 val = in_ptr[i]; + uint8_t* vals = reinterpret_cast(&val); + + #pragma unroll + for (int j = 0; j < 16; j++) vals[j] = process_val(vals[j]); + + out_ptr[i] = *reinterpret_cast(vals); + } +}; + + +torch::Tensor marlin_fp8_scales_preprocess(torch::Tensor scales) { + TORCH_CHECK(scales.device().is_cuda(), "scales is not on GPU"); + + int dev = scales.get_device(); + torch::Tensor out_scales = torch::empty_like(scales); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + marlin_fp8_scales_preprocess_kernel<<<256, 512, 0, stream>>>( + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(out_scales.data_ptr()), + scales.nbytes() + ); + + return out_scales; +} + + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_fp8_scales_preprocess", &marlin_fp8_scales_preprocess); +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index 2fd6d79bb6bf..a0ed6ba9400c 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -290,11 +290,12 @@ __global__ void Marlin( const 
int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -333,6 +334,13 @@ __global__ void Marlin( has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); + scalar_t2 global_scale; + + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + constexpr bool has_act_order = group_blocks == 0; constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); @@ -489,7 +497,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -548,7 +556,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -900,8 +908,15 @@ __global__ void Marlin( int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr(w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[ + s_sh_rd + cur_group_id / 2 * 2 * s_sh_stride + + cur_group_id % 2]; + } } } @@ -1111,6 +1126,14 @@ __global__ void Marlin( } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
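  // Note on the kFE2M1f branch above: for nvfp4 the per-group scales are
  // stored as packed fp8 (e4m3) bytes rather than fp16/bf16, so
  // dequant_fp8_scales() expands them to scalar_t2 in registers here, while
  // the fp16/bf16 global scale read from scale2_ptr is applied to the
  // accumulated results during write-out (the __hmul2(res, global_scale)
  // further below).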
#pragma unroll @@ -1416,6 +1439,10 @@ __global__ void Marlin( res = __hmul2(res, s[0]); } + if constexpr(w_type == vllm::kFE2M1f) { + res = __hmul2(res, global_scale); + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f59b42d88c61..4346f24e1abc 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -292,13 +292,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " + "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " "Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); // conditionally compiled so impl registration is in source file + // process marlin fp8 scales (used for W4A16-FP4) + ops.def("marlin_fp8_scales_preprocess(Tensor scales) -> Tensor", + {stride_tag}); + // conditionally compiled so impl registration is in source file + // gptq_marlin repack from GPTQ. ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 44377ccb2959..5662a0737524 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -328,6 +328,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -778,6 +779,10 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +def marlin_fp8_scales_preprocess(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.marlin_fp8_scales_preprocess(scales) + + def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: @@ -810,6 +815,7 @@ def gptq_marlin_gemm(a: torch.Tensor, c: Optional[torch.Tensor], b_q_weight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_zeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -822,8 +828,9 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, b_q_type.id, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, global_scale, + b_zeros, g_idx, perm, workspace, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5593e305a006..df897c0c3eca 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -404,10 +404,10 @@ def apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - 
extra_scale_factor=layer.marlin_extra_scale_factor, bias=bias) output_dtype = x.dtype diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index bfc0a5e2200d..f35c0f87d413 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -39,10 +39,10 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, workspace: torch.Tensor, size_n: int, size_k: int, - extra_scale_factor: int = 1, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP4 hardware support, we can leverage the @@ -57,13 +57,11 @@ def apply_fp4_marlin_linear( device=input.device, dtype=input.dtype) - if extra_scale_factor != 1 and size_n > size_k: - reshaped_x = reshaped_x * extra_scale_factor - output = ops.gptq_marlin_gemm(a=reshaped_x, c=None, b_q_weight=weight, b_scales=weight_scale, + global_scale=weight_scale_2, b_zeros=None, g_idx=None, perm=None, @@ -75,8 +73,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if extra_scale_factor != 1 and size_n <= size_k: - output = output * extra_scale_factor if bias is not None: output.add_(bias) # In-place add @@ -116,26 +112,30 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - scales = layer.weight_scale.to(torch.float32) * \ - layer.weight_scale_2.to(torch.float32) - - layer.marlin_extra_scale_factor = 1 - if scales.max() >= 2: - # We would scaling the `scales` tensors later, it would overflow - # if the value is greater than or equal to 2 ** fp4_exponent = 4. - # So we first divide the scales by a certain value to avoid overflow. - # Afterwards, we will multiply the computation results of - # the Marlin kernel by this value. 
- s = 2**(scales.max() / 2).log2().ceil().item() - layer.marlin_extra_scale_factor = s - scales = scales / s - - marlin_scales = marlin_permute_scales(s=scales.T.to(param_dtype), - size_k=part_size_k, - size_n=part_size_n, - group_size=16) - marlin_scales = fp4_fused_exponent_bias_into_scales(marlin_scales) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + + weight_scale = ops.marlin_fp8_scales_preprocess(weight_scale) + weight_scale = weight_scale.view(weight_scale.size(0) // 2, 2, -1, 8) + weight_scale = weight_scale.permute(0, 2, 1, 3).reshape( + weight_scale.size(0) * 2, -1) + weight_scale = weight_scale.view(-1, 4)[:, [0, 2, 1, 3]].view( + weight_scale.size(0), -1).to(torch.float8_e4m3fn) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2 + if param_dtype == torch.half: + weight_scale_2 = weight_scale_2 * (2.0**7) + elif param_dtype == torch.bfloat16: + weight_scale_2 = weight_scale_2 * (2.0**119) + + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2.to(param_dtype), + requires_grad=False) + + return def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: From fe3ea6e0faaa237987a7cf566bf62abef61ac3a4 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 11:41:21 +0800 Subject: [PATCH 16/35] fix for fp8 Signed-off-by: Jinzhen Lin --- .../layers/quantization/utils/marlin_utils_fp8.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index d516db537687..0e35a2963b4a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -271,7 +271,6 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - scales = fp8_fused_exponent_bias_into_scales(scales) scales = torch.nn.Parameter(scales, requires_grad=False) setattr(layer, name + "_weight_scale", scales) @@ -321,6 +320,5 @@ def marlin_quant_fp8_torch(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) return weight_ref.T, marlin_qweight, marlin_scales From e6047e50c74828efc0ffb6806a1fba52ec1b9a39 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 11:49:16 +0800 Subject: [PATCH 17/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +- tests/kernels/quantization/test_marlin_gemm.py | 5 +++++ vllm/model_executor/layers/quantization/hqq_marlin.py | 4 +++- .../model_executor/layers/quantization/utils/marlin_utils.py | 2 ++ .../layers/quantization/utils/marlin_utils_fp8.py | 1 + 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 78dc90ff6146..18a5e758d17d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -789,7 +789,7 @@ torch::Tensor gptq_marlin_gemm( "global_scale can only be used for float4_e2m1f."); } else { global_scale = torch::empty({0}, options); - TORCH_CHECK(b_q_type != vllm::kFE2M1f, + TORCH_CHECK(!(b_q_type == 
vllm::kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); } diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 51efa9806506..825a1b95dd4c 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -241,6 +241,7 @@ def test_gptq_marlin_gemm( g_idx = None sort_indices = None marlin_zp = None + marlin_s2 = None elif quant_type == scalar_types.float8_e4m3fn: if group_size not in [-1, 128]: return @@ -251,6 +252,7 @@ def test_gptq_marlin_gemm( g_idx = None sort_indices = None marlin_zp = None + marlin_s2 = None elif has_zp: if group_size == 16: return @@ -258,12 +260,14 @@ def test_gptq_marlin_gemm( b_weight, quant_type, group_size) g_idx = None sort_indices = None + marlin_s2 = None else: if group_size == 16: return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, quant_type, group_size, act_order) marlin_zp = None + marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) @@ -279,6 +283,7 @@ def test_gptq_marlin_gemm( None, marlin_q_w, marlin_s, + marlin_s2, marlin_zp, g_idx, sort_indices, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 7bd398137e02..856dac9d3163 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -304,8 +304,10 @@ def apply( marlin_out = ops.gptq_marlin_gemm( x, + None, layer.marlin_qweight, scales, + None, zeros, layer.g_idx, layer.g_idx_sort_indices, @@ -315,7 +317,7 @@ def apply( self.output_size_per_partition, self.input_size_per_partition, True, # is_k_full - True, # has_zp + False, # use atomic add True, # use 32-bit reduce True, # use float zp ) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 2e7425e6166c..89268ef7a38b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -404,6 +404,7 @@ def apply_gptq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, @@ -449,6 +450,7 @@ def apply_awq_marlin_linear( None, weight, weight_scale, + None, weight_zp, g_idx, g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 0e35a2963b4a..30ca778d843e 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -58,6 +58,7 @@ def apply_fp8_marlin_linear( c=None, b_q_weight=weight, b_scales=weight_scale, + global_scale=None, b_zeros=None, g_idx=None, perm=None, From dd53ce9bd4297e90aa58e386240fd8dc80c76ba6 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 11:55:04 +0800 Subject: [PATCH 18/35] fix Signed-off-by: Jinzhen Lin --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b3bfe0af7f5..8888ca7a5278 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -352,6 +352,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu" 
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" From 810c95aed15c5de5ac982fba2065dd58a7732f37 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 11:56:30 +0800 Subject: [PATCH 19/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu index 7c4e933cba23..50b3549f7463 100644 --- a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu +++ b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu @@ -1,6 +1,7 @@ #include #include +#include "core/registration.h" __global__ void marlin_fp8_scales_preprocess_kernel(int4* __restrict__ in_ptr, From 0f071834bcf93f60f882fbb83daee699e8847a72 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 12:10:35 +0800 Subject: [PATCH 20/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/marlin_template.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index a0ed6ba9400c..624b1b1e029d 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -580,7 +580,16 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && @@ -904,7 +913,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; @@ -914,8 +923,7 @@ __global__ void Marlin( } else { reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast(sh_s_stage)[ - s_sh_rd + cur_group_id / 2 * 2 * s_sh_stride + - cur_group_id % 2]; + s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } From 7eb3f9bb8127c4048cea32f8862c5a94b0a0ede1 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 12:35:10 +0800 Subject: [PATCH 21/35] fix Signed-off-by: Jinzhen Lin --- .../quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu | 4 ++-- .../layers/quantization/utils/marlin_utils_fp4.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu index 50b3549f7463..e0ffc2c7b8f5 100644 --- a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu +++ b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu @@ -62,8 +62,8 @@ torch::Tensor marlin_fp8_scales_preprocess(torch::Tensor scales) { torch::Tensor out_scales = torch::empty_like(scales); cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); marlin_fp8_scales_preprocess_kernel<<<256, 512, 0, stream>>>( - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(out_scales.data_ptr()), + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(out_scales.data_ptr()), scales.nbytes() ); diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index f35c0f87d413..05576a329ef1 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -73,7 +73,6 @@ def apply_fp4_marlin_linear( use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce) - if bias is not None: output.add_(bias) # In-place add @@ -118,12 +117,12 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_n=part_size_n, group_size=16) - weight_scale = ops.marlin_fp8_scales_preprocess(weight_scale) weight_scale = weight_scale.view(weight_scale.size(0) // 2, 2, -1, 8) weight_scale = weight_scale.permute(0, 2, 1, 3).reshape( weight_scale.size(0) * 2, -1) weight_scale = weight_scale.view(-1, 4)[:, [0, 2, 1, 3]].view( weight_scale.size(0), -1).to(torch.float8_e4m3fn) + weight_scale = ops.marlin_fp8_scales_preprocess(weight_scale) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) weight_scale_2 = layer.weight_scale_2 From ed1db371d17a928f64ebe5e15dd4b341076c3750 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 13:15:45 +0800 Subject: [PATCH 22/35] fix test Signed-off-by: Jinzhen Lin --- tests/kernels/moe/test_moe.py | 2 +- .../kernels/quantization/test_marlin_gemm.py | 3 +-- .../quantization/utils/marlin_utils_fp4.py | 27 ++++++++++++++----- .../quantization/utils/marlin_utils_fp8.py | 4 ++- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index abf3e3667a75..31611c0c4199 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -380,7 +380,7 @@ def test_fused_marlin_moe( sort_indices1_l.append(sort_indices1) else: w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) + w1[i], group_size, True) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) diff --git a/tests/kernels/quantization/test_marlin_gemm.py 
b/tests/kernels/quantization/test_marlin_gemm.py index 825a1b95dd4c..14bc09c95fe9 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -236,12 +236,11 @@ def test_gptq_marlin_gemm( if quant_type == scalar_types.float4_e2m1f: if group_size != 16 or act_order: return - w_ref, marlin_q_w, marlin_s = rand_marlin_weight_fp4_like( + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( b_weight.T, group_size) g_idx = None sort_indices = None marlin_zp = None - marlin_s2 = None elif quant_type == scalar_types.float8_e4m3fn: if group_size not in [-1, 128]: return diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 05576a329ef1..466230fb66bf 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -251,6 +251,9 @@ def rand_marlin_weight_fp4_like(weight, group_size): device = weight.device scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + fp4_weight = torch.randint(0, 256, (size_n, size_k // 2), dtype=torch.uint8, @@ -268,8 +271,9 @@ def rand_marlin_weight_fp4_like(weight, group_size): weight_ref = torch.cat( [fp4_weight_part_2.unsqueeze(2), - fp4_weight_part_1.unsqueeze(2)], 2).view( - size_n, size_k) * scales.repeat_interleave(group_size, 1) + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), @@ -279,11 +283,20 @@ def rand_marlin_weight_fp4_like(weight, group_size): num_bits=4, ) - marlin_scales = marlin_permute_scales(s=scales.T, + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size) - - marlin_scales = fp4_fused_exponent_bias_into_scales(marlin_scales) - - return weight_ref.T, marlin_qweight, marlin_scales + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1).to(torch.float8_e4m3fn) + marlin_scales = ops.marlin_fp8_scales_preprocess(marlin_scales) + + if weight.dtype == torch.half: + global_scale = global_scale * (2.0**7) + elif weight.dtype == torch.bfloat16: + global_scale = global_scale * (2.0**119) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 30ca778d843e..287a14e4da1c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -293,7 +293,7 @@ def pack_fp8_to_int32(fp8_tensor: torch.Tensor, return int32_tensor.T.contiguous() if size_k_first else int32_tensor -def marlin_quant_fp8_torch(weight, group_size): +def marlin_quant_fp8_torch(weight, group_size, is_moe=False): size_n, size_k = weight.shape device = weight.device @@ -321,5 +321,7 @@ def marlin_quant_fp8_torch(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) + if not 
is_moe: + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) return weight_ref.T, marlin_qweight, marlin_scales From 168fb3ecfba4041dcc5afced5e20a4a083aa5312 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 17:01:02 +0800 Subject: [PATCH 23/35] fix Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/dequant.h | 16 +-- csrc/quantization/gptq_marlin/gptq_marlin.cu | 29 +++-- csrc/quantization/gptq_marlin/kernel.h | 17 ++- .../marlin_fp8_scales_preproocess.cu | 118 +++++++++--------- .../gptq_marlin/marlin_template.h | 33 ++--- tests/kernels/moe/test_moe.py | 2 +- .../kernels/quantization/test_marlin_gemm.py | 14 ++- vllm/_custom_ops.py | 8 +- .../layers/quantization/hqq_marlin.py | 2 +- 9 files changed, 119 insertions(+), 120 deletions(-) diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h index b49720a446ad..ae0d6c0f2002 100644 --- a/csrc/quantization/gptq_marlin/dequant.h +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -471,15 +471,12 @@ __device__ inline void dequant( } template -__device__ inline void dequant_fp8_scales( - int q, scalar_t2* frag_b); - +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); template <> -__device__ inline void dequant_fp8_scales( - int q, half2* frag_b) { - - int Out1 = (q & 0xFF00FF00) >> 1;; +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; q <<= 8; int Out2 = (q & 0xFF00FF00) >> 1; @@ -488,10 +485,9 @@ __device__ inline void dequant_fp8_scales( frag_b[0] = *reinterpret_cast(&Out2); }; - template <> -__device__ inline void dequant_fp8_scales( - int q, nv_bfloat162* frag_b) { +__device__ inline void dequant_fp8_scales(int q, + nv_bfloat162* frag_b) { constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; constexpr int MASK = 0x7F007F00; diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 18a5e758d17d..4a242f2050d5 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -315,14 +315,14 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) #define FP4_GET_IF(W_TYPE) \ FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ @@ -453,9 +453,9 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, } template -void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void* s2, - void* zp, void* g_idx, void* 
perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, int lda, void* workspace, +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, + int prob_m, int prob_n, int prob_k, int lda, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k_init, @@ -880,12 +880,11 @@ torch::Tensor gptq_marlin_gemm( marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - scales_ptr, global_scale.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else { diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h index 82f7df07ea79..f92056589d20 100644 --- a/csrc/quantization/gptq_marlin/kernel.h +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -7,15 +7,14 @@ #include "marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, \ - const uint16_t* __restrict__ scale2_ptr, \ - const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \ - int prob_k, int lda, int *locks, bool use_atomic_add, \ - bool use_fp32_reduce, int max_shared_mem +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template #include "core/registration.h" - __global__ void marlin_fp8_scales_preprocess_kernel(int4* __restrict__ in_ptr, -int4* __restrict__ out_ptr, int64_t s_size) { - - // convert subnormal fp8_e4m3 value to fp8_e5m3_val - // #0bEEEEEMMM // subnormal_e4m3_val = e5m3_val - const uint8_t subnormal_val_map[9] = { - 0b00000000, // 0 / 2 = 0 - 0b00110000, // 1 / 8 * (2 ** -6) = 1.00 * (2 ** (6 - 15)) - 0b00111000, // 2 / 8 * (2 ** -6) = 1.00 * (2 ** (7 - 15)) - 0b00111100, // 3 / 8 * (2 ** -6) = 1.50 * (2 ** (7 - 15)) - 0b01000000, // 4 / 8 * (2 ** -6) = 1.00 * (2 ** (8 - 15)) - 0b01000010, // 5 / 8 * (2 ** -6) = 1.25 * (2 ** (8 - 15)) - 0b01000100, // 6 / 8 * (2 ** -6) = 1.50 * (2 ** (8 - 15)) - 0b01000110, // 7 / 8 * (2 ** -6) = 1.75 * (2 ** (8 - 15)) - }; - - int offset = blockIdx.x * blockDim.x; - - // Note that after the conversion, - // the first bit of all values (except 0.0) is 1 - auto process_val = [&](uint8_t val) { - if (val == 0) return 0; - - // normalized value case - // (x | 0x80): set the top bit of exponent to 1 - // so that we have less exponent bias with 
fp16/bf16 - // (x - 8): divide the fp8 value by 2 - // to avoid the value become NaN after dequantization - // when x = *reinterpret_cast(&fp8_val) - // (x - 8 * y) means the exponent is decreased by y, - // which corresponds to dividing the fp8 value by 2 ** y - else if (val >= 8) return (val | 0x80) - 8; - - // subnormal value (all exponent bits is 0) - // (x - 8 * 8): to match the exponent bias used by normalized numbers - // (x - 8): same with normalized value case - else return (subnormal_val_map[val] | 0x80) - 8 * (8 + 1); - }; - - for (int i = offset + threadIdx.x; i < s_size / 16; i += blockDim.x) { - int4 val = in_ptr[i]; - uint8_t* vals = reinterpret_cast(&val); - - #pragma unroll - for (int j = 0; j < 16; j++) vals[j] = process_val(vals[j]); - - out_ptr[i] = *reinterpret_cast(vals); - } + int4* __restrict__ out_ptr, + int64_t s_size) { + // convert subnormal fp8_e4m3 value to fp8_e5m3_val + // #0bEEEEEMMM // subnormal_e4m3_val = e5m3_val + const uint8_t subnormal_val_map[9] = { + 0b00000000, // 0 / 2 = 0 + 0b00110000, // 1 / 8 * (2 ** -6) = 1.00 * (2 ** (6 - 15)) + 0b00111000, // 2 / 8 * (2 ** -6) = 1.00 * (2 ** (7 - 15)) + 0b00111100, // 3 / 8 * (2 ** -6) = 1.50 * (2 ** (7 - 15)) + 0b01000000, // 4 / 8 * (2 ** -6) = 1.00 * (2 ** (8 - 15)) + 0b01000010, // 5 / 8 * (2 ** -6) = 1.25 * (2 ** (8 - 15)) + 0b01000100, // 6 / 8 * (2 ** -6) = 1.50 * (2 ** (8 - 15)) + 0b01000110, // 7 / 8 * (2 ** -6) = 1.75 * (2 ** (8 - 15)) + }; + + int offset = blockIdx.x * blockDim.x; + + // Note that after the conversion, + // the first bit of all values (except 0.0) is 1 + auto process_val = [&](uint8_t val) { + if (val == 0) return 0; + + // normalized value case + // (x | 0x80): set the top bit of exponent to 1 + // so that we have less exponent bias with fp16/bf16 + // (x - 8): divide the fp8 value by 2 + // to avoid the value become NaN after dequantization + // when x = *reinterpret_cast(&fp8_val) + // (x - 8 * y) means the exponent is decreased by y, + // which corresponds to dividing the fp8 value by 2 ** y + else if (val >= 8) + return (val | 0x80) - 8; + + // subnormal value (all exponent bits is 0) + // (x - 8 * 8): to match the exponent bias used by normalized numbers + // (x - 8): same with normalized value case + else + return (subnormal_val_map[val] | 0x80) - 8 * (8 + 1); + }; + + for (int i = offset + threadIdx.x; i < s_size / 16; i += blockDim.x) { + int4 val = in_ptr[i]; + uint8_t* vals = reinterpret_cast(&val); + +#pragma unroll + for (int j = 0; j < 16; j++) vals[j] = process_val(vals[j]); + + out_ptr[i] = *reinterpret_cast(vals); + } }; - torch::Tensor marlin_fp8_scales_preprocess(torch::Tensor scales) { - TORCH_CHECK(scales.device().is_cuda(), "scales is not on GPU"); + TORCH_CHECK(scales.device().is_cuda(), "scales is not on GPU"); - int dev = scales.get_device(); - torch::Tensor out_scales = torch::empty_like(scales); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - marlin_fp8_scales_preprocess_kernel<<<256, 512, 0, stream>>>( - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(out_scales.data_ptr()), - scales.nbytes() - ); + int dev = scales.get_device(); + torch::Tensor out_scales = torch::empty_like(scales); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + marlin_fp8_scales_preprocess_kernel<<<256, 512, 0, stream>>>( + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(out_scales.data_ptr()), + scales.nbytes()); - return out_scales; + return out_scales; } - TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { 
m.impl("marlin_fp8_scales_preprocess", &marlin_fp8_scales_preprocess); } diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h index 624b1b1e029d..c49898210336 100644 --- a/csrc/quantization/gptq_marlin/marlin_template.h +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -290,12 +290,13 @@ __global__ void Marlin( const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 only) - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -556,7 +557,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == vllm::kFE2M1f ? 2 : 1) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -913,17 +915,18 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 
2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr(w_type_id != vllm::kFE2M1f.id()) { + if constexpr (w_type_id != vllm::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[ - s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1138,8 +1141,10 @@ __global__ void Marlin( int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); } // We have the m dimension as the inner loop in order to encourage overlapping @@ -1447,7 +1452,7 @@ __global__ void Marlin( res = __hmul2(res, s[0]); } - if constexpr(w_type == vllm::kFE2M1f) { + if constexpr (w_type == vllm::kFE2M1f) { res = __hmul2(res, global_scale); } diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 31611c0c4199..10d0e6581413 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -421,7 +421,7 @@ def test_fused_marlin_moe( sort_indices2_l.append(sort_indices2) else: w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) + w2[i], group_size, True) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 14bc09c95fe9..52507b375c27 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -270,12 +270,12 @@ def test_gptq_marlin_gemm( workspace = marlin_make_workspace_new(w_ref.device) - opcheck( - torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck(torch.ops._C.gptq_marlin_gemm, + (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, + sort_indices, workspace, quant_type.id, a_input.shape[0], + b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, + use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, @@ -418,6 +418,7 @@ def test_hqq_marlin_gemm( None, marlin_w_q, marlin_s, + None, marlin_zp, g_idx, g_idx_sort_indices, @@ -530,6 +531,7 @@ def test_marlin_gemm_subset_input(): None, marlin_q_w, marlin_s, + None, marlin_zp, g_idx, sort_indices, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5662a0737524..9d68ef4ca062 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -828,10 +828,10 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, global_scale, - b_zeros, g_idx, perm, workspace, - b_q_type.id, - size_m, size_n, size_k, is_k_full, + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) diff --git 
a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 856dac9d3163..e7511f330ea7 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -317,7 +317,7 @@ def apply( self.output_size_per_partition, self.input_size_per_partition, True, # is_k_full - False, # use atomic add + False, # use atomic add True, # use 32-bit reduce True, # use float zp ) From 25531ebabf3ecaa51770c23e7fdfbcbacbefc2e8 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 17:06:20 +0800 Subject: [PATCH 24/35] fix Signed-off-by: Jinzhen Lin --- csrc/torch_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4346f24e1abc..9797c786c9fc 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -292,8 +292,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " - "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, " - "Tensor? perm_or_none, Tensor workspace, int b_q_type, " + "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); From ed95abb2e30fe56cbfcc386a316039ac370adb62 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 20:10:54 +0800 Subject: [PATCH 25/35] fp4 moe marlin Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/generate_kernels.py | 13 +- csrc/moe/marlin_moe_wna16/kernel.h | 23 ++-- csrc/moe/marlin_moe_wna16/marlin_template.h | 124 ++++++++++++------ csrc/moe/marlin_moe_wna16/ops.cu | 86 +++++++++--- csrc/moe/torch_bindings.cpp | 3 +- tests/kernels/moe/test_moe.py | 68 +++++++--- vllm/_custom_ops.py | 11 +- .../layers/fused_moe/fused_marlin_moe.py | 14 +- .../quantization/utils/marlin_utils_fp8.py | 4 +- 9 files changed, 245 insertions(+), 101 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 902bcd9dfd21..15f008d4f61e 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -31,7 +31,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
-SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -39,7 +42,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -72,6 +75,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index c40c33d01f37..537282aba8c8 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,17 +7,18 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index c9e199bcea1f..752454d37644 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -301,9 +301,11 @@ __global__ void Marlin( int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* 
__restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens @@ -341,6 +343,16 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); @@ -348,7 +360,8 @@ __global__ void Marlin( constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -493,6 +506,11 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; if constexpr (has_zp) { @@ -606,7 +624,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -664,7 +682,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -688,10 +707,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -801,7 +830,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -1021,12 +1050,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1199,22 +1235,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1244,13 +1265,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. 
#pragma unroll @@ -1259,7 +1290,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1272,6 +1306,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1279,7 +1318,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1287,7 +1327,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1554,10 +1594,15 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if constexpr (w_type == vllm::kFE2M1f) { + res = __hmul2(res, global_scale); + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1648,7 +1693,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters, i); @@ -1737,7 +1784,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1747,7 +1795,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1771,7 +1820,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 
8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 00b4e934cc39..c293b121310b 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -46,7 +46,7 @@ __global__ void permute_cols_kernel( const int32_t* __restrict__ sorted_token_ids_ptr, const int32_t* __restrict__ expert_ids_ptr, const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, - int size_k, int top_k) {}; + int size_k, int top_k){}; } // namespace marlin @@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ @@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, BIGGROUP_GET_IF(vllm::kFE4M3fn) + FP4_GET_IF(vllm::kFE2M1f) + ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) @@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, @@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, bool m_block_size_8 = moe_block_size == 8; if (has_zp) { - TORCH_CHECK(q_type == vllm::kU4, - "q_type must be u4 when has_zp = True. Got = ", q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. 
Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); @@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm( if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. 
Got = ", b_q_type.str()); } @@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), @@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 2a8b9bb39caa..810026d034c0 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! 
num_tokens_past_padded," diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 10d0e6581413..ad452d2d899f 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -18,6 +18,8 @@ fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -293,11 +295,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, @pytest.mark.parametrize("topk", [2, 3]) @pytest.mark.parametrize("ep_size", [1, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) +@pytest.mark.parametrize("group_size", [-1, 16, 32, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("quant_type", [ scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn + scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f ]) @pytest.mark.parametrize("is_k_full", [True, False]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @@ -337,6 +339,11 @@ def test_fused_marlin_moe( if not is_k_full: return + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 @@ -354,12 +361,27 @@ def test_fused_marlin_moe( w_ref1_l = [] qweight1_l = [] scales1_l = [] + global_scale1_l = [] zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref1, qweight1, scales1, global_scale1 = rand_marlin_weight_fp4_like( + w1[i], group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + global_scale1_l.append(global_scale1) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( + w1[i], group_size) + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( w1[i].transpose(1, 0), quant_type, group_size) @@ -367,7 +389,7 @@ def test_fused_marlin_moe( qweight1_l.append(qweight1) scales1_l.append(scales1) zeros1_l.append(zeros1) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(k) w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ marlin_quantize(w1[i].transpose(1, 0), quant_type, @@ -378,16 +400,11 @@ def test_fused_marlin_moe( scales1_l.append(scales1) g_idx1_l.append(g_idx1) sort_indices1_l.append(sort_indices1) - else: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size, True) - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) + global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None 
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None @@ -395,12 +412,27 @@ def test_fused_marlin_moe( w_ref2_l = [] qweight2_l = [] scales2_l = [] + global_scale2_l = [] zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref2, qweight2, scales2, global_scale2 = rand_marlin_weight_fp4_like( + w2[i], group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + global_scale2_l.append(global_scale2) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( + w2[i], group_size) + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( w2[i].transpose(1, 0), quant_type, group_size) @@ -408,7 +440,7 @@ def test_fused_marlin_moe( qweight2_l.append(qweight2) scales2_l.append(scales2) zeros2_l.append(zeros2) - elif quant_type != scalar_types.float8_e4m3fn: + else: test_perm = torch.randperm(n) w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ marlin_quantize(w2[i].transpose(1, 0), quant_type, @@ -419,24 +451,18 @@ def test_fused_marlin_moe( scales2_l.append(scales2) g_idx2_l.append(g_idx2) sort_indices2_l.append(sort_indices2) - else: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size, True) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) + global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score, topk, False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) @@ -451,6 +477,8 @@ def test_fused_marlin_moe( topk_ids, global_num_experts=e, expert_map=e_map, + global_scale1=global_scale1, + global_scale2=global_scale2, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9d68ef4ca062..62a1c3fd4ef6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1268,6 +1268,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor], g_idx: Optional[torch.Tensor], perm: Optional[torch.Tensor], @@ -1282,11 +1283,11 @@ def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], use_fp32_reduce: bool, is_zp_float: bool) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_scales, b_qzeros, g_idx, perm, workspace, - sorted_token_ids, expert_ids, num_tokens_past_padded, topk_weights, - moe_block_size, top_k, mul_topk_weights, is_ep, b_q_type.id, size_m, - size_n, size_k, is_k_full, use_atomic_add, use_fp32_reduce, - is_zp_float) + input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, + perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, + topk_weights, 
moe_block_size, top_k, mul_topk_weights, is_ep, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, + use_fp32_reduce, is_zp_float) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index b96d34ec2db3..4c84dd538332 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -25,6 +25,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, sort_indices1: Optional[torch.Tensor] = None, @@ -64,11 +66,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor, quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn + scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f ] - int4_scalar_types = [scalar_types.uint4, scalar_types.uint4b8] - num_bits = 4 if quant_type in int4_scalar_types else 8 + bit4_scalar_types = [ + scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. assert hidden_states.shape[0] == gating_output.shape[ @@ -133,6 +137,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache1, w1, w1_scale, + global_scale1, w1_zeros, g_idx1, sort_indices1, @@ -165,6 +170,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, intermediate_cache3, w2, w2_scale, + global_scale2, w2_zeros, g_idx2, sort_indices2, @@ -202,6 +208,8 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, global_num_experts: int = -1, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, g_idx1: Optional[torch.Tensor] = None, g_idx2: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 287a14e4da1c..30ca778d843e 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -293,7 +293,7 @@ def pack_fp8_to_int32(fp8_tensor: torch.Tensor, return int32_tensor.T.contiguous() if size_k_first else int32_tensor -def marlin_quant_fp8_torch(weight, group_size, is_moe=False): +def marlin_quant_fp8_torch(weight, group_size): size_n, size_k = weight.shape device = weight.device @@ -321,7 +321,5 @@ def marlin_quant_fp8_torch(weight, group_size, is_moe=False): size_k=size_k, size_n=size_n, group_size=group_size) - if not is_moe: - marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) return weight_ref.T, marlin_qweight, marlin_scales From d7b2ac717ad67bdc69344e4f6620df66816c2184 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 20:18:53 +0800 Subject: [PATCH 26/35] fix Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index c293b121310b..6bf4d3fe36a8 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ 
b/csrc/moe/marlin_moe_wna16/ops.cu @@ -911,7 +911,7 @@ torch::Tensor moe_wna16_marlin_gemm( if (b_q_type == vllm::kFE2M1f) { scales_ptr = b_scales.data_ptr(); } else { - scales_ptr = b_scales.data_ptr(); + scales_ptr = b_scales.data_ptr(); } MARLIN_NAMESPACE_NAME::marlin_mm( From f09273b59ab3938310cef35bd266931f090b5653 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 20:26:22 +0800 Subject: [PATCH 27/35] fix Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/ops.cu | 2 +- tests/kernels/moe/test_moe.py | 12 ++++++------ .../layers/quantization/utils/marlin_utils_fp8.py | 2 ++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 6bf4d3fe36a8..2cff04f699b0 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -46,7 +46,7 @@ __global__ void permute_cols_kernel( const int32_t* __restrict__ sorted_token_ids_ptr, const int32_t* __restrict__ expert_ids_ptr, const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m, - int size_k, int top_k){}; + int size_k, int top_k) {}; } // namespace marlin diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ad452d2d899f..1dbb5f417ce1 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -16,10 +16,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( marlin_quant_fp8_torch) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -368,8 +368,8 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref1, qweight1, scales1, global_scale1 = rand_marlin_weight_fp4_like( - w1[i], group_size) + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_fp4_like(w1[i], group_size) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -419,8 +419,8 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: - w_ref2, qweight2, scales2, global_scale2 = rand_marlin_weight_fp4_like( - w2[i], group_size) + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_fp4_like(w2[i], group_size) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 30ca778d843e..3080d2a0da87 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -322,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size): size_n=size_n, group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + return weight_ref.T, marlin_qweight, marlin_scales From a82fcbf2ed7b582bd6df75dcd32f9d1d121093ff Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 21:08:47 +0800 Subject: [PATCH 28/35] add comment Signed-off-by: Jinzhen Lin --- 
.../quantization/utils/marlin_utils_fp4.py | 70 +++++++++++-------- 1 file changed, 40 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 466230fb66bf..9ab9783d45ae 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -21,18 +21,46 @@ def is_fp4_marlin_supported(): return current_platform.has_device_capability(80) -def fp4_fused_exponent_bias_into_scales(scales): +def fp4_marlin_process_scales(marlin_scales): + assert (marlin_scales >= 0).all() + + # convert to half first we would convert to fp8 later + marlin_scales = marlin_scales.torch(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + + # Why multiply 2 ** 7 ? + # After by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # hen weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. + marlin_scales = marlin_scales.to(torch.float8_e4m3fn) + marlin_scales = ops.marlin_fp8_scales_preprocess(marlin_scales) + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 - if scales.dtype == torch.half: + if global_scale.dtype == torch.half: target_exponent = 5 - elif scales.dtype == torch.bfloat16: + elif global_scale.dtype == torch.bfloat16: target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) - s = torch.ones_like(scales) * 2 - s = s**exponent_bias - return scales * s + return global_scale * (2.0 * (exponent_bias - 7)) def apply_fp4_marlin_linear( @@ -116,22 +144,12 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: size_k=part_size_k, size_n=part_size_n, group_size=16) - - weight_scale = weight_scale.view(weight_scale.size(0) // 2, 2, -1, 8) - weight_scale = weight_scale.permute(0, 2, 1, 3).reshape( - weight_scale.size(0) * 2, -1) - weight_scale = weight_scale.view(-1, 4)[:, [0, 2, 1, 3]].view( - weight_scale.size(0), -1).to(torch.float8_e4m3fn) - weight_scale = ops.marlin_fp8_scales_preprocess(weight_scale) + weight_scale = fp4_marlin_process_scales(weight_scale) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - weight_scale_2 = layer.weight_scale_2 - if param_dtype == torch.half: - weight_scale_2 = weight_scale_2 * (2.0**7) - elif param_dtype == torch.bfloat16: - weight_scale_2 = weight_scale_2 * (2.0**119) - - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2.to(param_dtype), + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) return @@ -287,16 +305,8 @@ def rand_marlin_weight_fp4_like(weight, group_size): size_k=size_k, size_n=size_n, group_size=group_size) - 
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) - marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1) - marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1).to(torch.float8_e4m3fn) - marlin_scales = ops.marlin_fp8_scales_preprocess(marlin_scales) + marlin_scales = fp4_marlin_process_scales(marlin_scales) - if weight.dtype == torch.half: - global_scale = global_scale * (2.0**7) - elif weight.dtype == torch.bfloat16: - global_scale = global_scale * (2.0**119) + global_scale = fp4_marlin_process_global_scale(global_scale) return weight_ref.T, marlin_qweight, marlin_scales, global_scale From 7177a72d053b81553b9d784d021a466459dfed05 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 21:19:39 +0800 Subject: [PATCH 29/35] fix Signed-off-by: Jinzhen Lin --- .../quantization/utils/marlin_utils_fp4.py | 48 ++++--------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 9ab9783d45ae..6c8090c1522f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -168,6 +168,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WORKSPACE device = layer.w13_weight.device + param_dtype = layer.params_dtype layer.workspace = marlin_make_workspace_new(device, 4) perm = torch.empty(0, dtype=torch.int, device=device) @@ -206,14 +207,8 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: group_size = layer.weight_block_size[1] for name in ["w13", "w2"]: - if name + "_weight_scale" in dir(layer): - new_name = name + "_weight_scale" - scales = getattr(layer, new_name).to(layer.orig_dtype) - delattr(layer, new_name) - elif name + "_weight_scale_inv" in dir(layer): - new_name = name + "_weight_scale_inv" - scales = getattr(layer, new_name).to(layer.orig_dtype) - delattr(layer, new_name) + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -221,47 +216,22 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: else: size_n, size_k = k, n - # marlin kernel only support channel-wise and group-wise quantization - # we need to convert the scales - if layer.weight_block_size is None: - if scales.nelement() == e: - # tensor-wise quantization -> channel-wise quantization - # (e, 1, 1) =>(repeat)=> (e, 1, size_n) - scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2) - elif scales.nelement() > e and scales.nelement() != e * size_n: - assert (e * size_n) % scales.nelement() == 0 - s_size = scales.nelement() // e - # tensor-wise quantization (for gate-up proj) - # -> channel-wise quantization - # (e, 1, s_size) =>(repeat)=> (e, 1, size_n) - scales = scales.view(e, 1, s_size) - scales = scales.repeat_interleave(size_n // s_size, 2) - else: - # channel-wise quantization - # (e, 1, size_n) - scales = scales.view(e, 1, size_n) - else: - # block-wise quantization -> group-wise quantization - # (e, size_k // block_size[1], ceil(size_n / block_size[0])) - # =>(repeat)=> (e, size_k // block_size[1], size_n) - block_n = layer.weight_block_size[0] - scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) - # size_n may not divisible by block_size[0] - scales = scales[..., 
:size_n].contiguous() - for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i], + marlin_scales = marlin_permute_scales(s=scales[i].T, size_k=size_k, size_n=size_n, group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - scales = fp4_fused_exponent_bias_into_scales(scales) scales = torch.nn.Parameter(scales, requires_grad=False) - setattr(layer, name + "_weight_scale", scales) + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + def rand_marlin_weight_fp4_like(weight, group_size): assert group_size > 0 From 4a6ac2a1a453914d666a2020aa02dd45b3d5cef2 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 21:44:16 +0800 Subject: [PATCH 30/35] fix Signed-off-by: Jinzhen Lin --- .../layers/quantization/utils/marlin_utils_fp4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 6c8090c1522f..bc6b66ea5c10 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -25,7 +25,7 @@ def fp4_marlin_process_scales(marlin_scales): assert (marlin_scales >= 0).all() # convert to half first we would convert to fp8 later - marlin_scales = marlin_scales.torch(torch.half) + marlin_scales = marlin_scales.to(torch.half) # 8 is the number of scale number using by one thread marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) @@ -60,7 +60,7 @@ def fp4_marlin_process_global_scale(global_scale): # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) - return global_scale * (2.0 * (exponent_bias - 7)) + return global_scale * (2.0**(exponent_bias - 7)) def apply_fp4_marlin_linear( From e6144ee5c9e313e80edea8720fd9a0ca94c10713 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 22:48:29 +0800 Subject: [PATCH 31/35] remove unused cuda kernel Signed-off-by: Jinzhen Lin --- CMakeLists.txt | 1 - .../marlin_fp8_scales_preproocess.cu | 74 ------------------- csrc/torch_bindings.cpp | 5 -- vllm/_custom_ops.py | 4 - .../quantization/utils/marlin_utils_fp4.py | 12 +-- 5 files changed, 6 insertions(+), 90 deletions(-) delete mode 100644 csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 8888ca7a5278..4b3bfe0af7f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -352,7 +352,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") set_gencode_flags_for_srcs( SRCS "${MARLIN_SRCS}" diff --git a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu b/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu deleted file mode 100644 index f5f1b55198a0..000000000000 --- a/csrc/quantization/gptq_marlin/marlin_fp8_scales_preproocess.cu +++ /dev/null @@ -1,74 +0,0 @@ - -#include -#include -#include "core/registration.h" - -__global__ void 
marlin_fp8_scales_preprocess_kernel(int4* __restrict__ in_ptr, - int4* __restrict__ out_ptr, - int64_t s_size) { - // convert subnormal fp8_e4m3 value to fp8_e5m3_val - // #0bEEEEEMMM // subnormal_e4m3_val = e5m3_val - const uint8_t subnormal_val_map[9] = { - 0b00000000, // 0 / 2 = 0 - 0b00110000, // 1 / 8 * (2 ** -6) = 1.00 * (2 ** (6 - 15)) - 0b00111000, // 2 / 8 * (2 ** -6) = 1.00 * (2 ** (7 - 15)) - 0b00111100, // 3 / 8 * (2 ** -6) = 1.50 * (2 ** (7 - 15)) - 0b01000000, // 4 / 8 * (2 ** -6) = 1.00 * (2 ** (8 - 15)) - 0b01000010, // 5 / 8 * (2 ** -6) = 1.25 * (2 ** (8 - 15)) - 0b01000100, // 6 / 8 * (2 ** -6) = 1.50 * (2 ** (8 - 15)) - 0b01000110, // 7 / 8 * (2 ** -6) = 1.75 * (2 ** (8 - 15)) - }; - - int offset = blockIdx.x * blockDim.x; - - // Note that after the conversion, - // the first bit of all values (except 0.0) is 1 - auto process_val = [&](uint8_t val) { - if (val == 0) return 0; - - // normalized value case - // (x | 0x80): set the top bit of exponent to 1 - // so that we have less exponent bias with fp16/bf16 - // (x - 8): divide the fp8 value by 2 - // to avoid the value become NaN after dequantization - // when x = *reinterpret_cast(&fp8_val) - // (x - 8 * y) means the exponent is decreased by y, - // which corresponds to dividing the fp8 value by 2 ** y - else if (val >= 8) - return (val | 0x80) - 8; - - // subnormal value (all exponent bits is 0) - // (x - 8 * 8): to match the exponent bias used by normalized numbers - // (x - 8): same with normalized value case - else - return (subnormal_val_map[val] | 0x80) - 8 * (8 + 1); - }; - - for (int i = offset + threadIdx.x; i < s_size / 16; i += blockDim.x) { - int4 val = in_ptr[i]; - uint8_t* vals = reinterpret_cast(&val); - -#pragma unroll - for (int j = 0; j < 16; j++) vals[j] = process_val(vals[j]); - - out_ptr[i] = *reinterpret_cast(vals); - } -}; - -torch::Tensor marlin_fp8_scales_preprocess(torch::Tensor scales) { - TORCH_CHECK(scales.device().is_cuda(), "scales is not on GPU"); - - int dev = scales.get_device(); - torch::Tensor out_scales = torch::empty_like(scales); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); - marlin_fp8_scales_preprocess_kernel<<<256, 512, 0, stream>>>( - reinterpret_cast(scales.data_ptr()), - reinterpret_cast(out_scales.data_ptr()), - scales.nbytes()); - - return out_scales; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_fp8_scales_preprocess", &marlin_fp8_scales_preprocess); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9797c786c9fc..a39217b3770b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -299,11 +299,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); // conditionally compiled so impl registration is in source file - // process marlin fp8 scales (used for W4A16-FP4) - ops.def("marlin_fp8_scales_preprocess(Tensor scales) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // gptq_marlin repack from GPTQ. 
ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 62a1c3fd4ef6..c0a37c9db3f2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -779,10 +779,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def marlin_fp8_scales_preprocess(scales: torch.Tensor) -> torch.Tensor: - return torch.ops._C.marlin_fp8_scales_preprocess(scales) - - def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index bc6b66ea5c10..bab5961a9685 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -24,7 +24,7 @@ def is_fp4_marlin_supported(): def fp4_marlin_process_scales(marlin_scales): assert (marlin_scales >= 0).all() - # convert to half first we would convert to fp8 later + # convert to half first, we would convert to fp8 later marlin_scales = marlin_scales.to(torch.half) # 8 is the number of scale number using by one thread @@ -39,13 +39,13 @@ def fp4_marlin_process_scales(marlin_scales): # We assume that weight_scale (FP8-S1E4M3) is always greater # than or equal to 0. So we can convert # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. - - # Why multiply 2 ** 7 ? - # After by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # After multiply by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 # hen weight_scale > 0. This allows us to have an exponent bias # closer to zero after dequantization. - marlin_scales = marlin_scales.to(torch.float8_e4m3fn) - marlin_scales = ops.marlin_fp8_scales_preprocess(marlin_scales) + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() return marlin_scales From 45910c1ac8f79228963eb78a47f9724c5dd4845c Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 9 May 2025 23:06:06 +0800 Subject: [PATCH 32/35] fix Signed-off-by: Jinzhen Lin --- .../layers/quantization/utils/marlin_utils_fp4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index bab5961a9685..aed3507b23e6 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -39,8 +39,8 @@ def fp4_marlin_process_scales(marlin_scales): # We assume that weight_scale (FP8-S1E4M3) is always greater # than or equal to 0. So we can convert # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. - # After multiply by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 - # hen weight_scale > 0. This allows us to have an exponent bias + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias # closer to zero after dequantization. 
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 From 7e0dbe854fcd630b58e31e9ac6bac3cdb7c9f02f Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 10 May 2025 10:10:29 +0800 Subject: [PATCH 33/35] fix test Signed-off-by: Jinzhen Lin --- tests/kernels/moe/test_moe.py | 71 ++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index d30a1b0f01ba..c1d0940f26cb 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -288,21 +288,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) +def marlin_moe_generate_valid_test_cases(): + import itertools + m_list = [1, 123, 666] + n_list = [128, 1024] + k_list = [256, 2048] + e_list = [4, 12] + topk_list = [2, 3] + ep_size_list = [1, 4] + dtype_list = [torch.half, torch.bfloat16] + group_size_list = [-1, 16, 32, 128] + act_order_list = [True, False] + quant_type_list = [ + scalar_types.float4_e2m1f, + scalar_types.float8_e4m3fn, + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.uint8b128, + ] + is_k_full_list = [True, False] + + all_combinations = itertools.product(m_list, n_list, k_list, e_list, + topk_list, ep_size_list, dtype_list, + group_size_list, act_order_list, + quant_type_list, is_k_full_list) + + def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, + quant_type, is_k_full): + + if quant_type == scalar_types.float8_e4m3fn and \ + group_size not in [-1, 128]: + return False + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return False + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return False + + # Filter act_order + if act_order: + if group_size in (-1, k, n): + return False + if quant_type not in [scalar_types.uint4b8]: + return False + elif not is_k_full: + return False + + return True + + cases = [] + for case in all_combinations: + if is_invalid(*case): + cases.append(case) + return cases + + @pytest.mark.flaky(reruns=2) -@pytest.mark.parametrize("m", [1, 123, 666]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("ep_size", [1, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 16, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("quant_type", [ - scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f -]) -@pytest.mark.parametrize("is_k_full", [True, False]) +@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," + "act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases()) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, From 18df7ec46889c1fb3e5312e16ad386d0471f5653 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 10 May 2025 15:57:15 +0800 Subject: [PATCH 34/35] fp4 moe support Signed-off-by: Jinzhen Lin --- csrc/moe/marlin_moe_wna16/marlin_template.h | 17 ++++++++++---- .../layers/quantization/modelopt.py | 22 +++++++++++++++++-- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 752454d37644..dedbe1b792f7 100644 --- 
a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -473,9 +473,16 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + if constexpr (w_type == vllm::kFE2M1f) { + sh_block_topk_weights[tid4 * 4 + i] = __hmul2( + global_scale, + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]))); + } else { + sh_block_topk_weights[tid4 * 4 + i] = + Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + } } } } @@ -1600,7 +1607,9 @@ __global__ void Marlin( } if constexpr (w_type == vllm::kFE2M1f) { - res = __hmul2(res, global_scale); + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } } if constexpr (m_block_size_8) { diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index cd69dc9c2b45..bc0b752a2426 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, - prepare_fp4_layer_for_marlin) + prepare_moe_fp4_layer_for_marlin, prepare_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -460,6 +460,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + + if not self.cutlass_nvfp4_supported: + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. 
Please use Blackwell and" + " above.") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -620,7 +630,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, requires_grad=False) - return + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g13_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + del layer.w13_blockscale_swizzled + del layer.w2_blockscale_swizzled def apply( self, From a21442d749012f3c394c4821171904f429261ab0 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 10 May 2025 17:28:39 +0800 Subject: [PATCH 35/35] fix moe support Signed-off-by: Jinzhen Lin --- .../layers/quantization/modelopt.py | 36 +++++++++++++++++-- .../quantization/utils/marlin_utils_fp4.py | 9 ++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index bc0b752a2426..bd9daa7c608a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, - prepare_moe_fp4_layer_for_marlin, prepare_fp4_layer_for_marlin) + prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -27,6 +27,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -478,6 +479,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise ValueError("NVFP4 quantization was selected, " " dynamic quantization is not supported.") + layer.num_experts = num_experts + layer.params_dtype = params_dtype layer.quant_config = self.quant_config weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn @@ -633,7 +636,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) - del layer.g13_alphas + del layer.g1_alphas del layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant @@ -658,6 +661,35 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ): + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + assert activation == "silu", 
"Only SiLU activation is supported." assert not apply_router_weight_on_input, ( "Router weight on input is not " diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index aed3507b23e6..15177af58ae6 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -191,7 +191,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: perm=perm, size_k=size_k, size_n=size_n, - num_bits=8) + num_bits=4) tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -201,11 +201,6 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Permute scales - if layer.weight_block_size is None: - group_size = -1 - else: - group_size = layer.weight_block_size[1] - for name in ["w13", "w2"]: scales = getattr(layer, name + "_weight_scale").to(param_dtype) global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) @@ -220,7 +215,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: marlin_scales = marlin_permute_scales(s=scales[i].T, size_k=size_k, size_n=size_n, - group_size=group_size) + group_size=16) marlin_scales = fp4_marlin_process_scales(marlin_scales) tensor_list.append(marlin_scales)