Commit bbb57ad
authored
feat: trtrllm-gen global scaled FP8 GEMMs (#1829)
In low latency context, it is not uncommon to encounter memory bandwidth
bound GEMMs with a tiny leading dimension M. These cases are currently
not addressed as efficiently as they could by library implementations.
To fill this gap, I propose to expose generated GEMM kernels optimized
for small batch sizes, which saturate memory bandwidth to a higher
degree.
The main challenge in doing so is that these GEMMs expect the weight
tensor (second operand) to be pre-processed into a layout more amenable
to maximizing memory bandwidth saturation. As such it is not practical
to expose them under the same API as the other GEMMs, as they are not
interchangeable without changing the caller's implementation. I have
tentatively exposed these GEMMs as "flavored" GEMMs, by contrast with
the more "vanilla" GEMMs currently available.
Summary of the changes:
- Added cpp runner to be jitted for these new GEMMs:
`csrc/trtllm_flavored_gemm_runner.cu`
- A separate `flashinfer/trtllm_flavored_gemm.py` file containing the
Python interface of the new GEMMs
- Some stylistic refactoring of the autotuner done while understanding
how it works
- Tests
- Benchmarks
- Some other minor cleanups along the way. Note I will undo the
extraction of `fp8_utils.py` as the implementations of `to_fp8` differ
between the places I extracted it for
Next step:
I will add more kernels for larger batch sizes. This is required because
the weight matrix shuffling commits the user to this interface.
Therefore, they also need efficient kernels for larger batches, which
they will encounter for prefills for example, when not doing disagg.
Benchmarking results on GB200:
```
m=1 n=2560 k=16384 9.65 TFLOPs/s over 0.008694 ms, 4.83 TB/s
m=1 n=2560 k=32768 11.34 TFLOPs/s over 0.014797 ms, 5.67 TB/s
m=1 n=5120 k=16384 15.10 TFLOPs/s over 0.011110 ms, 7.55 TB/s
m=1 n=5120 k=32768 12.21 TFLOPs/s over 0.027491 ms, 6.10 TB/s
m=1 n=8192 k=16384 11.75 TFLOPs/s over 0.022851 ms, 5.87 TB/s
m=1 n=8192 k=32768 13.06 TFLOPs/s over 0.041114 ms, 6.53 TB/s
m=2 n=2560 k=16384 18.38 TFLOPs/s over 0.009130 ms, 4.60 TB/s
m=2 n=2560 k=32768 21.21 TFLOPs/s over 0.015821 ms, 5.31 TB/s
m=2 n=5120 k=16384 30.21 TFLOPs/s over 0.011107 ms, 7.56 TB/s
m=2 n=5120 k=32768 24.41 TFLOPs/s over 0.027491 ms, 6.11 TB/s
m=2 n=8192 k=16384 23.43 TFLOPs/s over 0.022912 ms, 5.86 TB/s
m=2 n=8192 k=32768 26.15 TFLOPs/s over 0.041056 ms, 6.54 TB/s
m=4 n=2560 k=16384 36.22 TFLOPs/s over 0.009264 ms, 4.54 TB/s
m=4 n=2560 k=32768 43.55 TFLOPs/s over 0.015408 ms, 5.45 TB/s
m=4 n=5120 k=16384 60.40 TFLOPs/s over 0.011110 ms, 7.56 TB/s
m=4 n=5120 k=32768 48.82 TFLOPs/s over 0.027494 ms, 6.11 TB/s
m=4 n=8192 k=16384 46.71 TFLOPs/s over 0.022989 ms, 5.84 TB/s
m=4 n=8192 k=32768 52.10 TFLOPs/s over 0.041216 ms, 6.52 TB/s
m=8 n=2560 k=16384 72.47 TFLOPs/s over 0.009261 ms, 4.55 TB/s
m=8 n=2560 k=32768 84.84 TFLOPs/s over 0.015821 ms, 5.32 TB/s
m=8 n=5120 k=16384 120.84 TFLOPs/s over 0.011107 ms, 7.57 TB/s
m=8 n=5120 k=32768 97.37 TFLOPs/s over 0.027568 ms, 6.10 TB/s
m=8 n=8192 k=16384 93.41 TFLOPs/s over 0.022989 ms, 5.85 TB/s
m=8 n=8192 k=32768 104.21 TFLOPs/s over 0.041216 ms, 6.52 TB/s
m=16 n=2560 k=16384 138.70 TFLOPs/s over 0.009677 ms, 4.37 TB/s
m=16 n=2560 k=32768 174.22 TFLOPs/s over 0.015408 ms, 5.48 TB/s
m=16 n=5120 k=16384 231.03 TFLOPs/s over 0.011619 ms, 7.26 TB/s
m=16 n=5120 k=32768 190.13 TFLOPs/s over 0.028237 ms, 5.97 TB/s
m=16 n=8192 k=16384 180.96 TFLOPs/s over 0.023734 ms, 5.68 TB/s
m=16 n=8192 k=32768 205.52 TFLOPs/s over 0.041795 ms, 6.44 TB/s
m=32 n=2560 k=16384 260.92 TFLOPs/s over 0.010288 ms, 4.14 TB/s
m=32 n=2560 k=32768 322.64 TFLOPs/s over 0.016640 ms, 5.11 TB/s
m=32 n=5120 k=16384 421.01 TFLOPs/s over 0.012752 ms, 6.65 TB/s
m=32 n=5120 k=32768 371.18 TFLOPs/s over 0.028928 ms, 5.85 TB/s
m=32 n=8192 k=16384 348.80 TFLOPs/s over 0.024627 ms, 5.49 TB/s
m=32 n=8192 k=32768 400.89 TFLOPs/s over 0.042854 ms, 6.30 TB/s
m=64 n=2560 k=16384 466.29 TFLOPs/s over 0.011514 ms, 3.76 TB/s
m=64 n=2560 k=32768 458.96 TFLOPs/s over 0.023395 ms, 3.69 TB/s
m=64 n=5120 k=16384 673.11 TFLOPs/s over 0.015952 ms, 5.37 TB/s
m=64 n=5120 k=32768 679.79 TFLOPs/s over 0.031590 ms, 5.40 TB/s
m=64 n=8192 k=16384 648.00 TFLOPs/s over 0.026512 ms, 5.14 TB/s
m=64 n=8192 k=32768 766.41 TFLOPs/s over 0.044832 ms, 6.06 TB/s
```1 parent fd03820 commit bbb57ad
File tree
25 files changed
+1074
-90
lines changed- benchmarks
- csrc
- flashinfer
- fused_moe
- jit
- gemm
- include/flashinfer/trtllm/gemm/trtllmGen_gemm_export
- trtllm/gen
- tests
- gemm
- moe
25 files changed
+1074
-90
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
46 | | - | |
| 46 | + | |
| 47 | + | |
47 | 48 | | |
48 | 49 | | |
49 | 50 | | |
50 | | - | |
| 51 | + | |
51 | 52 | | |
52 | 53 | | |
53 | 54 | | |
54 | | - | |
| 55 | + | |
55 | 56 | | |
56 | 57 | | |
57 | 58 | | |
58 | | - | |
| 59 | + | |
59 | 60 | | |
60 | 61 | | |
61 | 62 | | |
| |||
0 commit comments