
Commit 96ad65b

Authored by LopezCastroRoberto, BlackSamorez, and mgoin
[Transform] [Quantization] Add QuTLASS support to vLLM (#24440)
Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Signed-off-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Andrei Panferov <andrei@panferov.org>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent 8d2b8c0 commit 96ad65b

12 files changed: +1848 -1 lines

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -834,6 +834,8 @@ steps:
     - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
     - pytest -v -s tests/kernels/moe/test_flashinfer.py
     - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
+    - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
+    - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
 
 - label: Blackwell GPT-OSS Eval
   timeout_in_minutes: 60

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1007,6 +1007,7 @@ endif()
 # For CUDA we also build and ship some external projects.
 if (VLLM_GPU_LANG STREQUAL "CUDA")
   include(cmake/external_projects/flashmla.cmake)
+  include(cmake/external_projects/qutlass.cmake)
 
   # vllm-flash-attn should be last as it overwrites some CMake functions
   include(cmake/external_projects/vllm_flash_attn.cmake)
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "mxfp4": dict(no_a_quant=False, enabled=True),
    "mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_mxfp4(
    b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
        b, forward_hadamard_matrix, method="abs_max"
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
    return weight_hf_e2m1, weight_hf_scale_block


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    alpha = torch.tensor([1.0], device="cuda")

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_mxfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_mxfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
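For orientation, the snippet below distills the call sequence this benchmark times for the "mxfp4" provider into a single standalone sketch: apply the Hadamard transform and quantize both operands to MXFP4 with fusedQuantizeMx, rearrange the e8m0 group scales into the blocked layout with to_blocked, then run the fused FP4 GEMM with matmul_mxf4_bf16_tn. It uses only ops already imported in the file above; the shapes (M=32, N=K=4096), the 32-element Hadamard group, and the expected output shape/dtype are illustrative assumptions, not taken from the commit.

# Minimal sketch of the MXFP4 quantize-then-GEMM path exercised above.
# Shapes and group size are illustrative assumptions, not from the commit.
import torch

from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked

M, N, K, group = 32, 4096, 4096, 32  # assumed sizes; K is a multiple of `group`
device, dtype = "cuda", torch.bfloat16

a = torch.randn((M, K), device=device, dtype=dtype)  # activation
b = torch.randn((N, K), device=device, dtype=dtype)  # weight, K innermost
had = deterministic_hadamard_matrix(group, dtype=dtype, device=device) * group**-0.5

# Fused Hadamard rotation + MXFP4 quantization: e2m1 values, e8m0 per-group scales.
a_e2m1, a_e8m0 = fusedQuantizeMx(a, had, method="abs_max")
b_e2m1, b_e8m0 = fusedQuantizeMx(b, had, method="abs_max")

# Rearrange the scales into the blocked layout the QuTLASS GEMM expects.
a_scales = to_blocked(a_e8m0, backend="triton")
b_scales = to_blocked(b_e8m0, backend="triton")

alpha = torch.tensor([1.0], device=device)
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scales, b_scales, alpha)
print(out.shape)  # expected (M, N), bfloat16

The benchmark script itself is driven from the command line (--models and --tp-sizes, defaulting to meta-llama/Llama-3.3-70B-Instruct at TP=1) and sweeps Hadamard group sizes of 32, 64, and 128 for each weight shape.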
