
Commit f49239c, authored by mgoin

Benchmark script for fp8 vs bf16 gemm (#17126)

Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 2dbe8c0 commit f49239c

2 files changed: +268 additions, −0 deletions
Lines changed: 222 additions & 0 deletions

@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import copy
import itertools

import torch
import triton
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        line_names=[
            "torch-bf16",
            # "fp8-tensor-w-token-a",
            "fp8-tensor-w-tensor-a",
            "fp8-channel-w-token-a",
            # "fp8-channel-w-tensor-a",
            # "fp8-tensor-w-token-a-noquant",
            "fp8-tensor-w-tensor-a-noquant",
            "fp8-channel-w-token-a-noquant",
            # "fp8-channel-w-tensor-a-noquant",
        ],
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    # Create input tensors
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if "torch-bf16" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )

    elif "fp8" in provider:
        # Weights are always quantized ahead of time
        if "noquant" in provider:
            # For the *-noquant variants, activations are also quantized ahead
            # of time, so only the GEMM itself is timed
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                    a, use_per_token_if_dynamic=True
                )

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N
                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)

                def run_quant():
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        else:
            # In these cases, we quantize the activations during the GEMM call
            if "tensor-w-token-a" in provider:
                # Dynamic per-token quant for A, per-tensor quant for B
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "tensor-w-tensor-a" in provider:
                # Static per-tensor quantization with fixed scales
                # for both A and B
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                assert scale_b_fp8.numel() == 1

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-token-a" in provider:
                # Static per-channel quantization for weights, per-token
                # quant for A
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
                        a, use_per_token_if_dynamic=True
                    )
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

            elif "channel-w-tensor-a" in provider:
                # Static per-channel quantization for weights, per-tensor
                # quant for A
                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
                assert scale_b_fp8.numel() == N

                def run_quant():
                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        b_fp8 = b_fp8.t()

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    # Calculate TFLOP/s, two flops per multiply-add
    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
    return tflops(ms), tflops(max_ms), tflops(min_ms)
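
# Worked example of the TFLOP/s formula above, using one of the shapes added in
# weight_shapes.py for meta-llama/Llama-3.1-8B-Instruct (K=4096, N=4096) at the
# largest batch size (M=16384); the 1 ms latency is an assumed value:
#   flops        = 2 * 16384 * 4096 * 4096 ≈ 5.5e11
#   tflops(1.0)  ≈ 5.5e11 * 1e-12 / 1e-3   ≈ 550 TFLOP/s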

def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=[*WEIGHT_SHAPES.keys()],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")

benchmarks/kernels/weight_shapes.py

Lines changed: 46 additions & 0 deletions
@@ -48,4 +48,50 @@
        ([16384, 106496], 1),
        ([53248, 16384], 0),
    ],
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}
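
The benchmark script's prepare_shapes() consumes these entries by dividing the tensor-parallel dimension of each [K, N] pair by the TP size. A minimal sketch, with the Llama-3.1-8B-Instruct entry copied from above and the TP sizes [1, 2] chosen purely for illustration:

import copy
import itertools

WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
}

for model, tp_size in itertools.product(WEIGHT_SHAPES, [1, 2]):
    for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
        # Shard the tensor-parallel dimension, as the benchmark does
        KN[tp_split_dim] //= tp_size
        K, N = KN
        print(f"{model} tp={tp_size}: K={K}, N={N}")

With tp_size=2 the output-split entries halve N (e.g. 6144 becomes 3072) while the input-split entries halve K (e.g. 14336 becomes 7168), matching how the corresponding layers are sharded under tensor parallelism.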
