|
30 | 30 |
|
31 | 31 | namespace vllm { |
32 | 32 |
|
// SiLU (sigmoid-linear unit) computed in float32: x * sigmoid(x),
// i.e. x / (1 + exp(-x)). Uses the fast-math intrinsics __expf and
// __fdividef deliberately (reduced-ulp, higher-throughput variants).
__device__ __forceinline__ float silu(float x) {
  const float denom = 1.f + __expf(-x);
  return __fdividef(x, denom);
}
55 | 37 |
|
// Element-wise SiLU over both lanes of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  float2 out;
  out.x = silu(x.x);
  out.y = silu(x.y);
  return out;
}
|
// Computes silu(vec) * vec2 element-wise over a packed vector.
//
// Each packed element (half2 or __nv_bfloat162, selected by Type) is widened
// to float2, the SiLU and the multiply are performed in float32 for accuracy,
// and the result is rounded back to the packed 16-bit type (_rn = round to
// nearest even).
//
// Parameters:
//   vec  - gate input; SiLU is applied to this operand.
//   vec2 - linear input; multiplied by silu(vec).
// Returns the packed product silu(vec) * vec2.
//
// Fix: removed the unused local alias
//   using packed_type = typename TypeConverter<Type>::Type;
// which was declared but never referenced in the body.
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
                                                       PackedVec<Type>& vec2) {
  PackedVec<Type> result;

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    // silu_mul in float32, then round back to the packed 16-bit type.
    if constexpr (std::is_same_v<Type, half>) {
      float2 silu_vec = silu2(__half22float2(vec.elts[i]));
      result.elts[i] =
          __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
    } else {
      float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
      result.elts[i] = __float22bfloat162_rn(
          __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
    }
  }
  return result;
}
131 | 63 |
|
132 | 64 | // Use UE4M3 by default. |
133 | 65 | template <class Type, bool UE8M0_SF = false> |
134 | 66 | __global__ void __launch_bounds__(1024, 4) |
135 | | - silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, |
| 67 | + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, |
136 | 68 | float const* SFScale, uint32_t* out, |
137 | 69 | uint32_t* SFout) { |
138 | 70 | using PackedVec = PackedVec<Type>; |
@@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4) |
160 | 92 | // Get the output tensor offset. |
161 | 93 | // Same as inOffset because 8 elements are packed into one uint32_t. |
162 | 94 | int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; |
163 | | - ; |
164 | 95 | auto& out_pos = out[outOffset]; |
165 | 96 |
|
| 97 | + // Compute silu and mul |
| 98 | + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); |
| 99 | + |
166 | 100 | auto sf_out = |
167 | 101 | cvt_quant_to_fp4_get_sf_out_offset<uint32_t, |
168 | 102 | CVT_FP4_NUM_THREADS_PER_SF>( |
169 | 103 | rowIdx, colIdx, numCols, SFout); |
170 | 104 |
|
171 | | - out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>( |
172 | | - in_vec, in_vec2, SFScaleVal, sf_out); |
| 105 | + out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal, |
| 106 | + sf_out); |
173 | 107 | } |
174 | 108 | } |
175 | 109 | } |
@@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] |
204 | 138 | input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { |
205 | 139 | using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; |
206 | 140 | auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); |
207 | | - vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>( |
| 141 | + vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>( |
208 | 142 | m, n, input_ptr, input_sf_ptr, |
209 | 143 | reinterpret_cast<uint32_t*>(output_ptr), |
210 | 144 | reinterpret_cast<uint32_t*>(sf_out)); |
|
0 commit comments