# SPDX-License-Identifier: Apache-2.0
- # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton

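+ # Provider configurations: each entry selects the weight-quant granularity
+ # ("tensor" or "channel"), the activation-quant granularity ("token" or
+ # "tensor"), and whether activation quantization is hoisted out of the timed
+ # region (no_a_quant). Only entries with enabled=True are benchmarked.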
+ PROVIDER_CFGS = {
+     "torch-bf16": dict(enabled=True),
+     "fp8-tensor-w-token-a": dict(
+         w="tensor", a="token", no_a_quant=False, enabled=False
+     ),
+     "fp8-tensor-w-tensor-a": dict(
+         w="tensor", a="tensor", no_a_quant=False, enabled=True
+     ),
+     "fp8-channel-w-token-a": dict(
+         w="channel", a="token", no_a_quant=False, enabled=True
+     ),
+     "fp8-channel-w-tensor-a": dict(
+         w="channel", a="tensor", no_a_quant=False, enabled=False
+     ),
+     "fp8-tensor-w-token-a-noquant": dict(
+         w="tensor", a="token", no_a_quant=True, enabled=False
+     ),
+     "fp8-tensor-w-tensor-a-noquant": dict(
+         w="tensor", a="tensor", no_a_quant=True, enabled=True
+     ),
+     "fp8-channel-w-token-a-noquant": dict(
+         w="channel", a="token", no_a_quant=True, enabled=True
+     ),
+     "fp8-channel-w-tensor-a-noquant": dict(
+         w="channel", a="tensor", no_a_quant=True, enabled=False
+     ),
+ }
+
+ _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
+
+
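+ # Quantize the weight matrix to FP8 ahead of time: per-tensor with a static
+ # unit scale, or per-channel (one scale per output row of b) via dynamic
+ # per-token quantization. The FP8 weight is returned transposed for the GEMM.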
+ def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
+     if w_type == "tensor":
+         scale_b = torch.ones(1, device=device, dtype=torch.float32)
+         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+     else:
+         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
+     return b_fp8.t(), scale_b_fp8
+
+
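+ # Build a closure that runs one FP8 GEMM configuration. Weights are always
+ # quantized up front; activations are either pre-quantized (no_a_quant) or
+ # quantized inside the closure so that cost is included in the measurement.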
+ def build_fp8_runner(cfg, a, b, dtype, device):
+     b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)
+
+     scale_a_const = (
+         torch.ones(1, device=device, dtype=torch.float32)
+         if cfg["a"] == "tensor"
+         else None
+     )
+
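+     # no_a_quant: quantize the activation once here, so only the GEMM is timed.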
+     if cfg["no_a_quant"]:
+         if cfg["a"] == "tensor":
+             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+         else:
+             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+
+         def run():
+             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+         return run
+
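+     # Otherwise quantize the activation on every call, inside the timed closure.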
+     if cfg["a"] == "tensor":
+
+         def run():
+             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+     else:
+
+         def run():
+             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+     return run
+

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
-         line_vals=[
-             "torch-bf16",
-             # "fp8-tensor-w-token-a",
-             "fp8-tensor-w-tensor-a",
-             "fp8-channel-w-token-a",
-             # "fp8-channel-w-tensor-a",
-             # "fp8-tensor-w-token-a-noquant",
-             "fp8-tensor-w-tensor-a-noquant",
-             "fp8-channel-w-token-a-noquant",
-             # "fp8-channel-w-tensor-a-noquant",
-         ],
-         line_names=[
-             "torch-bf16",
-             # "fp8-tensor-w-token-a",
-             "fp8-tensor-w-tensor-a",
-             "fp8-channel-w-token-a",
-             # "fp8-channel-w-tensor-a",
-             # "fp8-tensor-w-token-a-noquant",
-             "fp8-tensor-w-tensor-a-noquant",
-             "fp8-channel-w-token-a-noquant",
-             # "fp8-channel-w-tensor-a-noquant",
-         ],
+         line_vals=_enabled,
+         line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
@@ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K):
    device = "cuda"
    dtype = torch.bfloat16

-     # Create input tensors
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

-     if "torch-bf16" in provider:
+     if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
-
-     elif "fp8" in provider:
-         # Weights are always quantized ahead of time
-         if "noquant" in provider:
-             # For no quantization, we just measure the GEMM
-             if "tensor-w-token-a" in provider:
-                 # Dynamic per-token quant for A, per-tensor quant for B
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                 assert scale_b_fp8.numel() == 1
-                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                     a, use_per_token_if_dynamic=True
-                 )
-
-                 def run_quant():
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "tensor-w-tensor-a" in provider:
-                 # Static per-tensor quantization with fixed scales
-                 # for both A and B
-                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 assert scale_b_fp8.numel() == 1
-                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                 def run_quant():
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "channel-w-token-a" in provider:
-                 # Static per-channel quantization for weights, per-token
-                 # quant for A
-                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                 assert scale_b_fp8.numel() == N
-                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                     a, use_per_token_if_dynamic=True
-                 )
-
-                 def run_quant():
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "channel-w-tensor-a" in provider:
-                 # Static per-channel quantization for weights, per-tensor
-                 # quant for A
-                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                 assert scale_b_fp8.numel() == N
-                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                 def run_quant():
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-         else:
-             # In these cases, we quantize the activations during the GEMM call
-             if "tensor-w-token-a" in provider:
-                 # Dynamic per-token quant for A, per-tensor quant for B
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                 assert scale_b_fp8.numel() == 1
-
-                 def run_quant():
-                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                         a, use_per_token_if_dynamic=True
-                     )
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "tensor-w-tensor-a" in provider:
-                 # Static per-tensor quantization with fixed scales
-                 # for both A and B
-                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 assert scale_b_fp8.numel() == 1
-
-                 def run_quant():
-                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "channel-w-token-a" in provider:
-                 # Static per-channel quantization for weights, per-token
-                 # quant for A
-                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                 assert scale_b_fp8.numel() == N
-
-                 def run_quant():
-                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                         a, use_per_token_if_dynamic=True
-                     )
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-             elif "channel-w-tensor-a" in provider:
-                 # Static per-channel quantization for weights, per-tensor
-                 # quant for A
-                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                 assert scale_b_fp8.numel() == N
-
-                 def run_quant():
-                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-         b_fp8 = b_fp8.t()
-
+     else:
+         cfg = PROVIDER_CFGS[provider]
+         run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

-     # Calculate TFLOP/s, two flops per multiply-add
-     tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
-     return tflops(ms), tflops(max_ms), tflops(min_ms)
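+     # 2 * M * N * K FLOPs per GEMM (two per multiply-add), reported as TFLOP/s.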
+     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
+     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
-     KN_model_names = []
-     models_tps = list(itertools.product(args.models, args.tp_sizes))
-     for model, tp_size in models_tps:
-         assert model in WEIGHT_SHAPES
-         for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
-             KN[tp_split_dim] = KN[tp_split_dim] // tp_size
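+     # Shrink each weight shape's tensor-parallel dimension by tp_size and tag it
+     # with the model name, yielding [K, N, model] entries for the main loop.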
+     out = []
+     for model, tp_size in itertools.product(args.models, args.tp_sizes):
+         for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
+             KN[tp_dim] //= tp_size
            KN.append(model)
-             KN_model_names.append(KN)
-     return KN_model_names
+             out.append(KN)
+     return out


if __name__ == "__main__":
@@ -197,21 +140,13 @@ def prepare_shapes(args):
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
-         choices=[*WEIGHT_SHAPES.keys()],
-         help="List of models to benchmark",
-     )
-     parser.add_argument(
-         "--tp-sizes",
-         nargs="+",
-         type=int,
-         default=[1],
-         help="List of tensor parallel sizes",
+         choices=list(WEIGHT_SHAPES.keys()),
    )
+     parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

-     KN_model_names = prepare_shapes(args)
-     for K, N, model_name in KN_model_names:
-         print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
+     for K, N, model in prepare_shapes(args):
+         print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,