
Commit bbb57ad

feat: trtllm-gen global scaled FP8 GEMMs (#1829)
In low-latency contexts, it is not uncommon to encounter memory-bandwidth-bound GEMMs with a tiny leading dimension M. These cases are currently not addressed as efficiently as they could be by library implementations. To fill this gap, I propose to expose generated GEMM kernels optimized for small batch sizes, which saturate memory bandwidth to a higher degree.

The main challenge in doing so is that these GEMMs expect the weight tensor (the second operand) to be pre-processed into a layout more amenable to maximizing memory-bandwidth saturation. It is therefore not practical to expose them under the same API as the other GEMMs, since they are not interchangeable without changing the caller's implementation. I have tentatively exposed these GEMMs as "flavored" GEMMs, by contrast with the more "vanilla" GEMMs currently available.

Summary of the changes:
- Added a C++ runner to be JIT-compiled for these new GEMMs: `csrc/trtllm_flavored_gemm_runner.cu`
- Added a separate `flashinfer/trtllm_flavored_gemm.py` file containing the Python interface of the new GEMMs
- Some stylistic refactoring of the autotuner, done while understanding how it works
- Tests
- Benchmarks
- Some other minor cleanups along the way

Note: I will undo the extraction of `fp8_utils.py`, as the implementations of `to_fp8` differ between the places I extracted it from.

Next step: I will add more kernels for larger batch sizes. This is required because the weight-matrix shuffling commits the user to this interface, so they also need efficient kernels for the larger batches they will encounter, for example during prefill when not doing disaggregated serving.

Benchmarking results on GB200:

```
m=1 n=2560 k=16384 9.65 TFLOPs/s over 0.008694 ms, 4.83 TB/s
m=1 n=2560 k=32768 11.34 TFLOPs/s over 0.014797 ms, 5.67 TB/s
m=1 n=5120 k=16384 15.10 TFLOPs/s over 0.011110 ms, 7.55 TB/s
m=1 n=5120 k=32768 12.21 TFLOPs/s over 0.027491 ms, 6.10 TB/s
m=1 n=8192 k=16384 11.75 TFLOPs/s over 0.022851 ms, 5.87 TB/s
m=1 n=8192 k=32768 13.06 TFLOPs/s over 0.041114 ms, 6.53 TB/s
m=2 n=2560 k=16384 18.38 TFLOPs/s over 0.009130 ms, 4.60 TB/s
m=2 n=2560 k=32768 21.21 TFLOPs/s over 0.015821 ms, 5.31 TB/s
m=2 n=5120 k=16384 30.21 TFLOPs/s over 0.011107 ms, 7.56 TB/s
m=2 n=5120 k=32768 24.41 TFLOPs/s over 0.027491 ms, 6.11 TB/s
m=2 n=8192 k=16384 23.43 TFLOPs/s over 0.022912 ms, 5.86 TB/s
m=2 n=8192 k=32768 26.15 TFLOPs/s over 0.041056 ms, 6.54 TB/s
m=4 n=2560 k=16384 36.22 TFLOPs/s over 0.009264 ms, 4.54 TB/s
m=4 n=2560 k=32768 43.55 TFLOPs/s over 0.015408 ms, 5.45 TB/s
m=4 n=5120 k=16384 60.40 TFLOPs/s over 0.011110 ms, 7.56 TB/s
m=4 n=5120 k=32768 48.82 TFLOPs/s over 0.027494 ms, 6.11 TB/s
m=4 n=8192 k=16384 46.71 TFLOPs/s over 0.022989 ms, 5.84 TB/s
m=4 n=8192 k=32768 52.10 TFLOPs/s over 0.041216 ms, 6.52 TB/s
m=8 n=2560 k=16384 72.47 TFLOPs/s over 0.009261 ms, 4.55 TB/s
m=8 n=2560 k=32768 84.84 TFLOPs/s over 0.015821 ms, 5.32 TB/s
m=8 n=5120 k=16384 120.84 TFLOPs/s over 0.011107 ms, 7.57 TB/s
m=8 n=5120 k=32768 97.37 TFLOPs/s over 0.027568 ms, 6.10 TB/s
m=8 n=8192 k=16384 93.41 TFLOPs/s over 0.022989 ms, 5.85 TB/s
m=8 n=8192 k=32768 104.21 TFLOPs/s over 0.041216 ms, 6.52 TB/s
m=16 n=2560 k=16384 138.70 TFLOPs/s over 0.009677 ms, 4.37 TB/s
m=16 n=2560 k=32768 174.22 TFLOPs/s over 0.015408 ms, 5.48 TB/s
m=16 n=5120 k=16384 231.03 TFLOPs/s over 0.011619 ms, 7.26 TB/s
m=16 n=5120 k=32768 190.13 TFLOPs/s over 0.028237 ms, 5.97 TB/s
m=16 n=8192 k=16384 180.96 TFLOPs/s over 0.023734 ms, 5.68 TB/s
m=16 n=8192 k=32768 205.52 TFLOPs/s over 0.041795 ms, 6.44 TB/s
m=32 n=2560 k=16384 260.92 TFLOPs/s over 0.010288 ms, 4.14 TB/s
m=32 n=2560 k=32768 322.64 TFLOPs/s over 0.016640 ms, 5.11 TB/s
m=32 n=5120 k=16384 421.01 TFLOPs/s over 0.012752 ms, 6.65 TB/s
m=32 n=5120 k=32768 371.18 TFLOPs/s over 0.028928 ms, 5.85 TB/s
m=32 n=8192 k=16384 348.80 TFLOPs/s over 0.024627 ms, 5.49 TB/s
m=32 n=8192 k=32768 400.89 TFLOPs/s over 0.042854 ms, 6.30 TB/s
m=64 n=2560 k=16384 466.29 TFLOPs/s over 0.011514 ms, 3.76 TB/s
m=64 n=2560 k=32768 458.96 TFLOPs/s over 0.023395 ms, 3.69 TB/s
m=64 n=5120 k=16384 673.11 TFLOPs/s over 0.015952 ms, 5.37 TB/s
m=64 n=5120 k=32768 679.79 TFLOPs/s over 0.031590 ms, 5.40 TB/s
m=64 n=8192 k=16384 648.00 TFLOPs/s over 0.026512 ms, 5.14 TB/s
m=64 n=8192 k=32768 766.41 TFLOPs/s over 0.044832 ms, 6.06 TB/s
```
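The intended call pattern is a one-time weight shuffle followed by ordinary GEMM calls against the pre-processed weights. Below is a minimal sketch (not part of the commit itself) based on the benchmark added in this commit, `benchmarks/bench_mm_fp8.py`; the FP8 quantization and the combined scale here are placeholders, while the helper names are the ones exercised by that benchmark.

```
import torch

from flashinfer import mm_fp8
from flashinfer.trtllm_low_latency_gemm import prepare_low_latency_gemm_weights

m, n, k = 8, 5120, 16384
# Placeholder FP8 operands and combined per-tensor scale; real code would
# quantize properly and track the inverse scales as in the benchmark below.
act_fp8 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
weight_fp8 = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
global_scale = torch.tensor(1.0, device="cuda")

# One-time: shuffle the (n, k) weight into the bandwidth-friendly layout,
# caching the permute indices per weight shape.
permute_indices_cache = {}
prepared_weights = prepare_low_latency_gemm_weights(weight_fp8, permute_indices_cache)

# Per call: small-M GEMM against the pre-shuffled weights.
out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
mm_fp8(act_fp8, prepared_weights, global_scale, out=out)
```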
1 parent fd03820 commit bbb57ad

File tree

25 files changed: +1074 -90 lines


benchmarks/bench_mm_fp8.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
```
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Dict
from flashinfer.autotuner import autotune
from flashinfer.trtllm_low_latency_gemm import prepare_low_latency_gemm_weights
import numpy as np
import torch

from flashinfer import mm_fp8
from flashinfer.testing.utils import bench_gpu_time

_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}


def to_float8(
    x: torch.Tensor, dtype=torch.float8_e4m3fn
) -> tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


def bench_mm_fp8(m, n, k, in_dtype, out_dtype):
    torch.manual_seed(123)
    input_tensor = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input_tensor, dtype=in_dtype)

    # mat2 row major -> column major
    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=in_dtype)

    res = torch.zeros([m, n], device="cuda", dtype=out_dtype)
    global_scale = input_inv_s * mat2_inv_s

    # Do row shuffling.
    prepared_weights = prepare_low_latency_gemm_weights(
        mat2_fp8, _cache_permute_indices
    )

    with autotune(True):
        mm_fp8(
            input_fp8,
            prepared_weights,
            global_scale,
            out=res,
        )

    measurements = bench_gpu_time(
        lambda: mm_fp8(
            input_fp8,
            prepared_weights,
            global_scale,
            res,
        ),
        dry_run_time_ms=500,
        repeat_time_ms=2500,
        use_cuda_graph=True,
    )
    ms = np.median(measurements)
    tflops_per_second = 2 * m * n * k * 1e-9 / ms

    bandwidth = (
        (
            input_fp8.numel() * input_fp8.element_size()
            + prepared_weights.numel() * prepared_weights.element_size()
            + res.numel() * res.element_size()
        )
        / ms
        / 1e9
    )

    print(
        f"mm_fp8 m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s over {ms:.6f} ms, {bandwidth:.2f} TB/s"
    )


if __name__ == "__main__":
    for m in [1, 2, 4, 8, 16, 32, 64]:
        for n in [2560, 5120, 8192]:
            for k in [16384, 32768]:
                bench_mm_fp8(m, n, k, torch.float8_e4m3fn, torch.bfloat16)
```
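For reference, the TFLOPs/s and TB/s figures this script prints can be checked by hand against the table in the commit message. A small sanity check (not part of the commit) reproducing the first row from the formulas above, assuming 1-byte FP8 operands and a 2-byte bfloat16 output:

```
# Reproduce the first benchmark row (m=1, n=2560, k=16384 at 0.008694 ms).
m, n, k, ms = 1, 2560, 16384, 0.008694
tflops_per_second = 2 * m * n * k * 1e-9 / ms        # ~9.65 TFLOPs/s
bytes_moved = m * k * 1 + n * k * 1 + m * n * 2      # fp8 A + fp8 B + bf16 out
tb_per_second = bytes_moved / ms / 1e9               # ~4.83 TB/s
print(f"{tflops_per_second:.2f} TFLOPs/s, {tb_per_second:.2f} TB/s")
```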

csrc/trtllm_gemm_runner.cu

Lines changed: 5 additions & 4 deletions
@@ -43,19 +43,20 @@ struct TrtllmGenGemmRunnerOptions {
```
 int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K,
                           const gemm::gemm::GemmInterface& interface) {
   static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO =
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_TN_transOut_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_rM_TN_"
+      "transOut_"
       "noShflA_dsFp8_schedP2x2x1x3_sm100f";

   static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO =
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_"
       "transOut_noShflA_dsFp8_schedS_sm100f";

   static constexpr const char* KERNEL_NAME_LARGE_N =
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_"
       "transOut_noShflA_dsFp8_schedP2x2x1x3_sm100f";

   static constexpr const char* KERNEL_NAME_DEFAULT =
-      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_TN_"
+      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_rM_TN_"
       "transOut_noShflA_dsFp8_schedS_sm100f";

   double const n_k_ratio = static_cast<double>(N) / static_cast<double>(K);
```
