
Commit ae2e8c2

elvischenv and yewentao256 authored and committed

[Bugfix] Fix accuracy issue for silu_mul + nvfp4 quant fusion kernel (vllm-project#24833)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

1 parent 8935751 commit ae2e8c2

File tree

5 files changed: +110 -226 lines

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 1 deletion
@@ -796,7 +796,7 @@ steps:
     # Quantization
     - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
     - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
-    - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
+    - pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
     - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
     - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
     - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py

csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu

Lines changed: 27 additions & 93 deletions
@@ -30,109 +30,41 @@

 namespace vllm {

-template <class Type>
-__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
-                                                   PackedVec<Type>& vec2) {
-  PackedVec<Type> result;
-#pragma unroll
-  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
-    if constexpr (std::is_same_v<Type, half>) {
-      half2 val(0.5f, 0.5f);
-      half2 t0 = __hmul2(vec.elts[i], val);
-      half2 t1 = __hfma2(h2tanh(t0), val, val);
-      half2 t2 = __hmul2(vec.elts[i], t1);
-      result.elts[i] = __hmul2(t2, vec2.elts[i]);
-    } else {
-      __nv_bfloat162 val(0.5f, 0.5f);
-      __nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
-      __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
-      __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
-      result.elts[i] = __hmul2(t2, vec2.elts[i]);
-    }
-  }
-  return result;
+// silu in float32
+__device__ __forceinline__ float silu(float x) {
+  return __fdividef(x, (1.f + __expf(-x)));
 }

-// Quantizes the provided PackedVec into the uint32_t output
-template <class Type, bool UE8M0_SF = false>
-__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
-                                                  PackedVec<Type>& vec2,
-                                                  float SFScaleVal,
-                                                  uint8_t* SFout) {
-  PackedVec<Type> out_silu = compute_silu(vec, vec2);
-  // Get absolute maximum values among the local 8 values.
-  auto localMax = __habs2(out_silu.elts[0]);
-
-// Local maximum value.
-#pragma unroll
-  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
-    localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
-  }
-
-  // Get the absolute maximum among all 16 values (two threads).
-  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
-  // Get the final absolute maximum values.
-  float vecMax = float(__hmax(localMax.x, localMax.y));
-
-  // Get the SF (max value of the vector / max value of e2m1).
-  // maximum value of e2m1 = 6.0.
-  // TODO: use half as compute data type.
-  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
-  // 8 bits representation of the SF.
-  uint8_t fp8SFVal;
-  // Write the SF to global memory (STG.8).
-  if constexpr (UE8M0_SF) {
-    // Extract the 8 exponent bits from float32.
-    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
-    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
-    fp8SFVal = tmp & 0xff;
-    // Convert back to fp32.
-    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
-  } else {
-    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
-    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
-    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
-    // Convert back to fp32.
-    SFValue = float(tmp);
-  }
-  // Get the output scale.
-  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
-  //                       reciprocal(SFScaleVal))
-  float outputScale =
-      SFValue != 0 ? reciprocal_approximate_ftz(
-                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
-                   : 0.0f;
-
-  if (SFout) {
-    // Write the SF to global memory (STG.8).
-    *SFout = fp8SFVal;
-  }
+__device__ __forceinline__ float2 silu2(float2 x) {
+  return make_float2(silu(x.x), silu(x.y));
+}

-  // Convert the input to float.
-  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
+template <class Type>
+__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
+                                                        PackedVec<Type>& vec2) {
+  PackedVec<Type> result;
+  using packed_type = typename TypeConverter<Type>::Type;

 #pragma unroll
-  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
+    // silu_mul in float32
     if constexpr (std::is_same_v<Type, half>) {
-      fp2Vals[i] = __half22float2(out_silu.elts[i]);
+      float2 silu_vec = silu2(__half22float2(vec.elts[i]));
+      result.elts[i] =
+          __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
     } else {
-      fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
+      float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
+      result.elts[i] = __float22bfloat162_rn(
+          __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
     }
-    fp2Vals[i].x *= outputScale;
-    fp2Vals[i].y *= outputScale;
   }
-
-  // Convert to e2m1 values.
-  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
-
-  // Write the e2m1 values to global memory.
-  return e2m1Vec;
+  return result;
 }

 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
 __global__ void __launch_bounds__(1024, 4)
-    silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
+    silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                              float const* SFScale, uint32_t* out,
                              uint32_t* SFout) {
   using PackedVec = PackedVec<Type>;
@@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4)
     // Get the output tensor offset.
     // Same as inOffset because 8 elements are packed into one uint32_t.
     int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-    ;
     auto& out_pos = out[outOffset];

+    // Compute silu and mul
+    PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);
+
     auto sf_out =
         cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                            CVT_FP4_NUM_THREADS_PER_SF>(
            rowIdx, colIdx, numCols, SFout);

-    out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
-        in_vec, in_vec2, SFScaleVal, sf_out);
+    out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
+                                                   sf_out);
   }
 }
 }
@@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
       input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
         using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
         auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
-        vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
+        vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
            m, n, input_ptr, input_sf_ptr,
            reinterpret_cast<uint32_t*>(output_ptr),
            reinterpret_cast<uint32_t*>(sf_out));
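Note: the old kernel evaluated SiLU with a half-precision tanh approximation before quantizing; the rewritten compute_silu_mul evaluates silu(x) = x / (1 + exp(-x)) in float32, multiplies by the second input half, and only then hands the half/bfloat16 result to the shared cvt_warp_fp16_to_fp4 path. The following is a minimal PyTorch sketch of that reference recipe, not the vLLM implementation; silu_mul_ref is an illustrative helper, and the layout assumption (first half of the last dimension is the SiLU input, second half is the multiplier) follows SiluAndMul.

import torch

def silu_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]; split into the SiLU input and the multiplier.
    d = x.shape[-1] // 2
    a = x[..., :d].float()                 # SiLU input, promoted to float32
    b = x[..., d:].float()                 # multiplier half, promoted to float32
    silu_a = a / (1.0 + torch.exp(-a))     # silu in float32, as in the new kernel
    return (silu_a * b).to(x.dtype)        # multiply, then cast back to half/bf16

if __name__ == "__main__":
    x = torch.randn(4, 32, dtype=torch.float16)
    print(silu_mul_ref(x).shape)  # torch.Size([4, 16])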

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 7 additions & 6 deletions
@@ -98,8 +98,9 @@ def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]


-@pytest.mark.parametrize("num_tokens", [64])
-@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_tokens", [32, 64])
+@pytest.mark.parametrize("hidden_size", [128, 256])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize(
     "model_class",
     cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
@@ -110,13 +111,13 @@ def ops_in_model_after(self):
                  [True, False] if cutlass_fp8_supported() else [True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
-def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
+def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
                                    cuda_force_torch):
     if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")

     torch.set_default_device("cuda")
-    torch.set_default_dtype(torch.float16)
+    torch.set_default_dtype(dtype)

     x = torch.rand(num_tokens, hidden_size * 2)

@@ -145,8 +146,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     elif model_class == TestSiluMulNvfp4QuantModel:
         atol, rtol = 1e-1, 1e-1

-    torch.testing.assert_close(result[0].to(dtype=torch.float16),
-                               result2[0].to(dtype=torch.float16),
+    torch.testing.assert_close(result[0].to(dtype=dtype),
+                               result2[0].to(dtype=dtype),
                                atol=atol,
                                rtol=rtol)
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
+                                                    FLOAT8_E4M3_MAX,
+                                                    dequantize_nvfp4_to_dtype)
+from vllm._custom_ops import scaled_fp4_quant
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms import current_platform
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
+                allow_module_level=True)
+
+FP4_DTYPE = torch.uint8
+FP8_DTYPE = current_platform.fp8_dtype()
+
+DTYPES = [torch.float16, torch.bfloat16]
+SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)]
+BLOCK_SIZE = 16
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@torch.inference_mode()
+def test_silu_mul_nvfp4_quant(
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+) -> None:
+    current_platform.seed_everything(42)
+    device = 'cuda:0'
+    torch.set_default_device(device)
+
+    x = torch.randn(shape, dtype=dtype)
+
+    # ref op
+    ref_output = SiluAndMul().forward_native(x)
+    ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                        torch.abs(ref_output).max().to(torch.float32))
+    ref_output_quant, ref_block_scale = scaled_fp4_quant(
+        ref_output, ref_global_scale)
+
+    # fused op
+    fused_output_quant = torch.empty_like(ref_output_quant)
+    fused_block_scale = torch.empty_like(ref_block_scale)
+    torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
+                                          fused_block_scale, x,
+                                          ref_global_scale)
+
+    # check dtype
+    assert ref_output_quant.dtype == FP4_DTYPE
+    assert fused_output_quant.dtype == FP4_DTYPE
+    assert ref_output_quant.shape == fused_output_quant.shape
+
+    assert ref_block_scale.dtype == FP8_DTYPE
+    assert fused_block_scale.dtype == FP8_DTYPE
+    assert ref_block_scale.shape == fused_block_scale.shape
+
+    # check dequantized output
+    ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
+                                                   ref_block_scale,
+                                                   ref_global_scale, dtype,
+                                                   device)
+    fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
+                                                     fused_block_scale,
+                                                     ref_global_scale, dtype,
+                                                     device)
+
+    atol, rtol = 3e-1, 3e-1
+    torch.testing.assert_close(ref_output_dequant,
+                               fused_output_dequant,
+                               atol=atol,
+                               rtol=rtol)
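Note: the test derives its global scale as (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax, with the constants imported from nvfp4_utils. A small worked example of that formula follows; the values 448.0 (largest finite FP8 E4M3) and 6.0 (largest E2M1, also noted in the kernel comments) are inlined here for illustration only.

import torch

FLOAT8_E4M3_MAX = 448.0   # assumed max finite FP8 E4M3 value
FLOAT4_E2M1_MAX = 6.0     # assumed max FP4 E2M1 value

out = torch.tensor([0.1, -2.5, 3.0, -0.75])
amax = out.abs().max().to(torch.float32)                   # 3.0
global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax  # 448 * 6 / 3
print(global_scale)  # tensor(896.)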
