@mgoin (Member) commented Aug 20, 2025

Purpose

The weight.shape % 128 != 0 limitation of the SM90 CUTLASS implementation was an artifact of our custom implementation; modern CUTLASS handles these shapes natively. We can therefore remove csrc/cutlass_extensions/gemm/ entirely and migrate to using upstream CUTLASS directly.

A caveat: the current support in CUTLASS imposes an M % 4 == 0 restriction and a specific weight block-scales layout when using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum as the MainloopScheduler. This PR works around both restrictions (a hypothetical sketch of M-padding follows).
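As a minimal sketch of one way to satisfy an M % 4 == 0 restriction (an illustrative assumption, not necessarily the workaround this PR actually uses), the activation rows can be zero-padded up front and the padding sliced off the output:

```python
import torch
import torch.nn.functional as F

def pad_rows(x: torch.Tensor, multiple: int = 4) -> torch.Tensor:
    """Hypothetical helper: zero-pad the M (row) dimension up to the next
    multiple so a kernel with an M % multiple == 0 restriction can run."""
    pad = (-x.shape[0]) % multiple
    return F.pad(x, (0, 0, 0, pad)) if pad else x

# Usage sketch: run the GEMM on padded input, then drop the padded rows.
a = torch.randn(6, 128, dtype=torch.bfloat16)   # M=6 is not a multiple of 4
out = pad_rows(a) @ torch.randn(128, 64, dtype=torch.bfloat16)
out = out[:a.shape[0]]
```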

This means we can replace the Triton kernel for the DeepSeekV3 layers kv_a_proj_with_mqa (shape [576, 7168]) and fused_qkv_a_proj (shape [2112, 7168]), improving performance on Hopper; the quick check below shows why these shapes missed the old kernel.
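Both output dimensions leave a remainder under the old weight.shape % 128 == 0 requirement, which is why these projections previously fell back to Triton:

```python
for name, n in [("kv_a_proj_with_mqa", 576), ("fused_qkv_a_proj", 2112)]:
    print(f"{name}: N={n}, N % 128 = {n % 128}")  # both print a remainder of 64
```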

Additional tuning should come next to improve performance at smaller M.

Test Plan

Added a kernel test case that passes. Will manually test DeepSeek and profile.

Test Result

GSM8k for DSV3:

vllm serve deepseek-ai/DeepSeek-V3 -tp 8 --no-enable-prefix-caching
python tests/evals/gsm8k/gsm8k_eval.py

# main
Accuracy: 0.948, Questions per second: 11.423
Accuracy: 0.951, Questions per second: 11.600
Accuracy: 0.944, Questions per second: 11.324

# PR
Accuracy: 0.947, Questions per second: 11.216
Accuracy: 0.943, Questions per second: 11.189
Accuracy: 0.948, Questions per second: 11.331

Profiling result with a dummy DeepSeek model, using llm = LLM(model="deepseek-ai/DeepSeek-R1", hf_overrides={'num_hidden_layers': 4}, load_format="dummy"). The _w8a8_block_fp8_matmul Triton kernel appears in the before profile and is entirely absent in the after profile, with better end-to-end performance.
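A self-contained version of that profiling setup, taken directly from the snippet above:

```python
from vllm import LLM

# Dummy-weight DeepSeek-R1 truncated to 4 layers: exercises the block-FP8
# GEMMs under the profiler without loading real weights.
llm = LLM(
    model="deepseek-ai/DeepSeek-R1",
    hf_overrides={"num_hidden_layers": 4},
    load_format="dummy",
)
```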

python examples/offline_inference/simple_profiling.py

# Before (notice _w8a8_block_fp8_matmul triton kernel)
Processed prompts: 100%|███████████████| 4/4 [00:00<00:00, 38.75it/s, est. speed input: 252.11 toks/s, output: 620.56 toks/s]
(EngineCore_0 pid=1785027) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1785027)                                                    Name      Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
(EngineCore_0 pid=1785027) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1785027) void cutlass::device_kernel<vllm::cutlass_3x_gemm_fp...       16.578ms        35.45%      16.578ms      59.207us           280  
(EngineCore_0 pid=1785027)                                                aten::mm       10.615ms        22.70%      11.239ms     661.133us            17  
(EngineCore_0 pid=1785027)                       nvjet_tst_512x8_64x3_2x1_v_bz_TNT       10.615ms        22.70%      10.615ms     624.384us            17  
(EngineCore_0 pid=1785027)                                  _w8a8_block_fp8_matmul        5.273ms        11.28%       5.273ms      77.544us            68  
(EngineCore_0 pid=1785027)                            _flashmla_C::fwd_kvcache_mla        3.292ms         7.04%       3.447ms      53.867us            64  
(EngineCore_0 pid=1785027)                                        fused_moe_kernel        2.344ms         5.01%       2.344ms      68.940us            34  
(EngineCore_0 pid=1785027) void flash_fwd_splitkv_mla_kernel<Traits<cutlass::bf...        2.091ms         4.47%       2.091ms      32.668us            64  
(EngineCore_0 pid=1785027) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1785027) Self CPU time total: 68.859ms
(EngineCore_0 pid=1785027) Self CUDA time total: 46.764ms

# After (only cutlass_3x_gemm_fp8)
Processed prompts: 100%|██████████████| 4/4 [00:00<00:00, 46.75it/s, est. speed input: 304.99 toks/s, output: 750.68 toks/s]
(EngineCore_0 pid=1801324) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1801324)                                                    Name      Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
(EngineCore_0 pid=1801324) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1801324) void cutlass::device_kernel<vllm::cutlass_3x_gemm_fp...       17.804ms        40.34%      17.804ms      51.162us           348  
(EngineCore_0 pid=1801324)                                                aten::mm       10.609ms        24.04%      11.234ms     660.841us            17  
(EngineCore_0 pid=1801324)                       nvjet_tst_512x8_64x3_2x1_v_bz_TNT       10.609ms        24.04%      10.609ms     624.064us            17  
(EngineCore_0 pid=1801324)                            _flashmla_C::fwd_kvcache_mla        3.312ms         7.50%       3.450ms      53.905us            64  
(EngineCore_0 pid=1801324)                                        fused_moe_kernel        2.351ms         5.33%       2.351ms      69.141us            34  
(EngineCore_0 pid=1801324) void flash_fwd_splitkv_mla_kernel<Traits<cutlass::bf...        2.098ms         4.75%       2.098ms      32.780us            64  
(EngineCore_0 pid=1801324) void per_token_group_quant_8bit_kernel<c10::BFloat16...        1.318ms         2.99%       1.318ms       3.788us           348  
(EngineCore_0 pid=1801324) -------------------------------------------------------   ------------  ------------  ------------  ------------  ------------  
(EngineCore_0 pid=1801324) Self CPU time total: 59.841ms
(EngineCore_0 pid=1801324) Self CUDA time total: 44.133ms

pytest -s -v "tests/kernels/quantization/test_block_fp8.py" -k "test_w8a8_block_fp8_cutlass_matmul"

tests/kernels/quantization/test_block_fp8.py::test_w8a8_block_fp8_cutlass_matmul PASSED

========================================================= 1 passed, 324 deselected in 2.59s ==========================================================
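For the benchmark tables below, the TFLOP/s figures are presumably derived the standard way, 2·M·N·K floating-point operations per GEMM divided by the measured kernel time (an assumption about the benchmark script, not taken from it):

```python
def gemm_tflops(m: int, n: int, k: int, seconds: float) -> float:
    # An (M, K) x (K, N) GEMM performs 2*M*N*K floating-point operations.
    return 2 * m * n * k / seconds / 1e12
```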
Benchmarking DeepSeek-V3, N=576 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 09-10 15:51:13 [fp8_utils.py:572] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs:
    batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0          1.0    1.118786               0.380050                0.245376
1         16.0   16.608691               5.750164                3.864525
2         64.0   62.026452              22.229612               15.007353
3        128.0   90.654547              33.863944               29.114532
4        256.0  134.696843              61.239278               53.284166
5        512.0  246.738788             112.315111               93.573581
6       1024.0  428.570759             161.635820              140.708858
7       2048.0  400.060736             200.281287              176.241851
8       4096.0  680.503931             225.348428              186.796922
9       8192.0  703.208631             244.516189              208.212890
10     16384.0  714.670788             250.868051              212.857175

Benchmarking DeepSeek-V3, N=2112 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 09-10 15:51:23 [fp8_utils.py:572] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs:
    batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0          1.0    3.059417               1.298436                0.893065
1         16.0   49.595406              20.755163               13.940020
2         64.0  163.467729              63.663467               55.792628
3        128.0  275.740510             118.310905              106.531540
4        256.0  401.863069             219.266445              192.465136
5        512.0  566.262626             305.159322              322.256922
6       1024.0  619.771938             366.839772              350.649020
7       2048.0  648.192575             396.032348              426.640953
8       4096.0  715.781529             402.982983              456.931446
9       8192.0  666.385367             406.511360              489.655688
10     16384.0  700.477731             409.090253              495.912795

Benchmarking DeepSeek-V3, N=24576 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 09-10 15:51:35 [fp8_utils.py:572] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs:
    batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0          1.0    3.019477               5.238588                4.078537
1         16.0   47.620245              84.375540               66.410762
2         64.0  187.282541             311.351460              272.827468
3        128.0  362.585807             408.295602              414.538562
4        256.0  643.247781             541.951724              522.809898
5        512.0  701.025029             567.283560              740.877703
6       1024.0  707.898854             577.334585              910.299344
7       2048.0  719.766032             572.063239              991.638151
8       4096.0  743.031776             600.573577             1010.715671
9       8192.0  732.942224             597.843862              996.655238
10     16384.0  748.777833             614.269851             1007.447214

Benchmarking DeepSeek-V3, N=32768 K=512
TFLOP/s comparison (block_size=(128, 128)):
INFO 09-10 15:51:47 [fp8_utils.py:572] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs:
    batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0          1.0    4.323625               3.726614                1.788261
1         16.0   70.206510              56.238394               30.672470
2         64.0  235.769937             200.112552              116.541109
3        128.0  305.891240             259.234134              158.486482
4        256.0  404.678436             280.259083              260.498704
5        512.0  470.984729             331.790846              365.325882
6       1024.0  527.756313             365.203479              463.643993
7       2048.0  549.251957             384.610034              569.167970
8       4096.0  573.356640             400.088473              600.554598
9       8192.0  582.517690             400.704071              638.472628
10     16384.0  591.007268             402.869576              640.536501

Benchmarking DeepSeek-V3, N=7168 K=16384
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:03:27 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.888605               3.611561                3.623948
1        16.0   45.120732              59.070079               54.894814
2        17.0   47.809703              51.185873               58.224845
3       512.0  656.558465             471.430313              558.842296
4      4096.0  742.291137             550.661025              916.429456
5      4097.0  728.547168             534.487990              883.800739
6     16381.0  772.366941             561.867271              902.320768
7     16384.0  742.071294             556.479367              861.439349

Benchmarking DeepSeek-V3, N=7168 K=18432
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:03:39 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.931162               3.631018                3.382976
1        16.0   45.103797              60.586643               56.725651
2        17.0   45.011694              44.709704               59.895291
3       512.0  653.996522             482.437006              549.407967
4      4096.0  742.398952             565.791399              924.489801
5      4097.0  730.168173             555.717302              923.881416
6     16381.0  773.192588             567.381955              894.148619
7     16384.0  742.440539             576.830125              873.733185

Benchmarking DeepSeek-V3, N=36864 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:03:51 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.915775               5.234165                4.004754
1        16.0   47.192357              88.392643               66.766938
2        17.0   50.276074              52.968448               70.536995
3       512.0  691.947588             572.249816              831.031351
4      4096.0  737.411114             620.200793             1079.029047
5      4097.0  743.744182             590.550357             1068.085453
6     16381.0  753.781731             589.830428             1049.508527
7     16384.0  747.349007             607.414113             1047.982039

Benchmarking DeepSeek-V3, N=24576 K=1536
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:04:02 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.709021               4.537728                2.820492
1        16.0   42.156864              79.662652               45.066369
2        17.0   44.450920              52.542601               47.904630
3       512.0  621.833777             441.276932              581.468714
4      4096.0  682.771839             524.414303              869.092606
5      4097.0  692.102159             495.779940              865.241290
6     16381.0  696.742236             533.151614              916.656840
7     16384.0  702.604973             537.093305              919.904056

Benchmarking DeepSeek-V3, N=12288 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:04:12 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.929049               4.189704                3.439284
1        16.0   45.804490              70.815990               53.838165
2        17.0   48.936398              48.391307               57.359623
3       512.0  697.968468             543.380499              692.536385
4      4096.0  731.913702             563.780243              943.747089
5      4097.0  737.210540             570.238741              963.468131
6     16381.0  738.202011             584.424675              913.025680
7     16384.0  747.327747             584.678850              938.286089

Benchmarking DeepSeek-V3, N=4096 K=7168
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:04:24 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    2.362971               1.989233                1.695491
1        16.0   36.424417              39.270706               27.112955
2        17.0   38.525986              34.472193               28.825433
3       512.0  673.659006             470.326090              393.773307
4      4096.0  723.197031             519.817275              706.946748
5      4097.0  712.154485             530.765948              680.583713
6     16381.0  746.091756             518.502791              730.153178
7     16384.0  714.929963             524.408922              747.778728

Benchmarking DeepSeek-V3, N=7168 K=2048
TFLOP/s comparison (block_size=(128, 128)):
INFO 08-28 14:04:33 [fp8_utils.py:569] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
BF16 vs W8A8 Block FP8 GEMMs (Triton vs CUTLASS):
   batch_size  torch-bf16  w8a8-block-fp8-triton  w8a8-block-fp8-cutlass
0         1.0    3.265824               2.405914                1.452183
1        16.0   63.774683              40.926749               25.460154
2        17.0   59.791679              35.218846               27.076483
3       512.0  612.787461             391.179224              383.144430
4      4096.0  686.795100             478.744753              746.921514
5      4097.0  656.573362             473.951384              741.106331
6     16381.0  718.258476             506.946802              798.761266
7     16384.0  691.668333             510.792224              790.806028


Signed-off-by: mgoin <mgoin64@gmail.com>

@gemini-code-assist (bot, Contributor) left a code review:

This pull request correctly enables the SM90 CUTLASS Block FP8 kernel for weight shapes that are not divisible by 128, addressing a previous limitation and a TODO in the code. The logic in vllm/model_executor/layers/quantization/utils/fp8_utils.py has been simplified and corrected to use a more general condition for SM90+ GPUs, which will improve performance for models like DeepSeekV3. The newly added test case in tests/kernels/quantization/test_block_fp8.py effectively validates that the CUTLASS kernel works as expected with non-aligned shapes. The changes are well-implemented and look good.

@mgoin marked this pull request as draft August 20, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin changed the title from "Allow SM90 CUTLASS Block FP8 when weight.shape % 128 != 0" to "Use upstream CUTLASS for SM90 Block FP8 kernel" Aug 20, 2025
@mgoin changed the title from "Use upstream CUTLASS for SM90 Block FP8 kernel" to "[WIP] Use upstream CUTLASS for SM90 Block FP8 kernel" Aug 27, 2025
@mgoin marked this pull request as ready for review August 27, 2025 15:59
@mgoin requested a review from LucasWilkinson as a code owner August 27, 2025 15:59
@mergify (bot) added the ci/build label Aug 27, 2025
mgoin added 8 commits August 28, 2025 13:08 (each Signed-off-by: mgoin <mgoin64@gmail.com>)
@mergify (bot) added the performance label Aug 29, 2025
@mgoin added the ready label Aug 30, 2025
@mgoin changed the title from "[WIP] Use upstream CUTLASS for SM90 Block FP8 kernel" to "[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel" Aug 30, 2025
@yewentao256 (Member) left a comment: Thanks for the work!

@yewentao256 (Member) left a comment: LGTM, thanks for the work!

@LucasWilkinson (Collaborator) left a comment: Amazing; thank you for doing this!

@simon-mo merged commit c3aea10 into vllm-project:main Sep 11, 2025 (71 of 73 checks passed)
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…3280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
…3280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…3280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…3280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…3280)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Labels: ci/build, performance, ready