|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +# |
| 4 | +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). |
| 5 | +# All Rights Reserved. |
| 6 | +# |
| 7 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 8 | +# you may not use this file except in compliance with the License. |
| 9 | +# You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, software |
| 14 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 15 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 16 | +# See the License for the specific language governing permissions and |
| 17 | +# limitations under the License. |
| 18 | +# |
| 19 | + |
| 20 | +import argparse |
| 21 | +import copy |
| 22 | +import itertools |
| 23 | + |
| 24 | +import torch |
| 25 | +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix |
| 26 | +from weight_shapes import WEIGHT_SHAPES |
| 27 | + |
| 28 | +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn |
| 29 | +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked |
| 30 | +from vllm.triton_utils import triton |
| 31 | + |
| 32 | +PROVIDER_CFGS = { |
| 33 | + "torch-bf16": dict(enabled=True), |
| 34 | + "mxfp4": dict(no_a_quant=False, enabled=True), |
| 35 | + "mxfp4-noquant": dict(no_a_quant=True, enabled=True), |
| 36 | +} |
| 37 | + |
| 38 | +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] |
| 39 | + |
| 40 | + |
| 41 | +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): |
| 42 | + return ( |
| 43 | + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) |
| 44 | + * group_size**-0.5 |
| 45 | + ) |
| 46 | + |
| 47 | + |
| 48 | +def _quant_weight_mxfp4( |
| 49 | + b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str |
| 50 | +): |
| 51 | + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx( |
| 52 | + b, forward_hadamard_matrix, method="abs_max" |
| 53 | + ) |
| 54 | + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton") |
| 55 | + return weight_hf_e2m1, weight_hf_scale_block |
| 56 | + |
| 57 | + |
| 58 | +def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): |
| 59 | + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4( |
| 60 | + b, forward_hadamard_matrix, device |
| 61 | + ) |
| 62 | + alpha = torch.tensor([1.0], device="cuda") |
| 63 | + |
| 64 | + if cfg["no_a_quant"]: |
| 65 | + # Pre-quantize activation |
| 66 | + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( |
| 67 | + a, forward_hadamard_matrix, method="abs_max" |
| 68 | + ) |
| 69 | + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") |
| 70 | + |
| 71 | + def run(): |
| 72 | + return matmul_mxf4_bf16_tn( |
| 73 | + input_hf_e2m1, |
| 74 | + weight_hf_e2m1, |
| 75 | + input_hf_scale_block, |
| 76 | + weight_hf_scale_block, |
| 77 | + alpha, |
| 78 | + ) |
| 79 | + |
| 80 | + return run |
| 81 | + |
| 82 | + # Quantize activation on-the-fly |
| 83 | + def run(): |
| 84 | + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( |
| 85 | + a, forward_hadamard_matrix, method="abs_max" |
| 86 | + ) |
| 87 | + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") |
| 88 | + return matmul_mxf4_bf16_tn( |
| 89 | + input_hf_e2m1, |
| 90 | + weight_hf_e2m1, |
| 91 | + input_hf_scale_block, |
| 92 | + weight_hf_scale_block, |
| 93 | + alpha, |
| 94 | + ) |
| 95 | + |
| 96 | + return run |
| 97 | + |
| 98 | + |
| 99 | +@triton.testing.perf_report( |
| 100 | + triton.testing.Benchmark( |
| 101 | + x_names=["batch_size"], |
| 102 | + x_vals=[ |
| 103 | + 1, |
| 104 | + 4, |
| 105 | + 8, |
| 106 | + 16, |
| 107 | + 32, |
| 108 | + 64, |
| 109 | + 128, |
| 110 | + 256, |
| 111 | + 512, |
| 112 | + 1024, |
| 113 | + 2048, |
| 114 | + 4096, |
| 115 | + 8192, |
| 116 | + 16384, |
| 117 | + 24576, |
| 118 | + 32768, |
| 119 | + ], |
| 120 | + x_log=False, |
| 121 | + line_arg="provider", |
| 122 | + line_vals=_enabled, |
| 123 | + line_names=_enabled, |
| 124 | + ylabel="TFLOP/s (larger is better)", |
| 125 | + plot_name="BF16 vs MXFP4 GEMMs", |
| 126 | + args={}, |
| 127 | + ) |
| 128 | +) |
| 129 | +def benchmark(batch_size, provider, N, K, had_size): |
| 130 | + M = batch_size |
| 131 | + device = "cuda" |
| 132 | + dtype = torch.bfloat16 |
| 133 | + |
| 134 | + a = torch.randn((M, K), device=device, dtype=dtype) |
| 135 | + b = torch.randn((N, K), device=device, dtype=dtype) |
| 136 | + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) |
| 137 | + |
| 138 | + quantiles = [0.5, 0.2, 0.8] |
| 139 | + |
| 140 | + if provider == "torch-bf16": |
| 141 | + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( |
| 142 | + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles |
| 143 | + ) |
| 144 | + else: |
| 145 | + cfg = PROVIDER_CFGS[provider] |
| 146 | + run_quant = build_mxfp4_runner( |
| 147 | + cfg, a, b, forward_hadamard_matrix, dtype, device |
| 148 | + ) |
| 149 | + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( |
| 150 | + lambda: run_quant(), rep=200, quantiles=quantiles |
| 151 | + ) |
| 152 | + |
| 153 | + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) |
| 154 | + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) |
| 155 | + |
| 156 | + |
| 157 | +def prepare_shapes(args): |
| 158 | + out = [] |
| 159 | + for model, tp_size in itertools.product(args.models, args.tp_sizes): |
| 160 | + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): |
| 161 | + KN[tp_dim] //= tp_size |
| 162 | + KN.append(model) |
| 163 | + out.append(KN) |
| 164 | + return out |
| 165 | + |
| 166 | + |
| 167 | +if __name__ == "__main__": |
| 168 | + parser = argparse.ArgumentParser() |
| 169 | + parser.add_argument( |
| 170 | + "--models", |
| 171 | + nargs="+", |
| 172 | + type=str, |
| 173 | + default=["meta-llama/Llama-3.3-70B-Instruct"], |
| 174 | + choices=list(WEIGHT_SHAPES.keys()), |
| 175 | + ) |
| 176 | + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) |
| 177 | + args = parser.parse_args() |
| 178 | + |
| 179 | + for K, N, model in prepare_shapes(args): |
| 180 | + for had_size in [32, 64, 128]: |
| 181 | + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:") |
| 182 | + benchmark.run( |
| 183 | + print_data=True, |
| 184 | + show_plots=True, |
| 185 | + save_path=f"bench_mxfp4_res_n{N}_k{K}", |
| 186 | + N=N, |
| 187 | + K=K, |
| 188 | + had_size=had_size, |
| 189 | + ) |
| 190 | + |
| 191 | + print("Benchmark finished!") |
0 commit comments