Commit 3d330c4

[Benchmark] Refactor benchmark script for fp8 & int8 (#19627)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 0b73736 commit 3d330c4

2 files changed: +184 −280 lines

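At a glance, the refactor replaces the hand-maintained line_vals / line_names lists and the long if/elif provider chain in benchmark() with a PROVIDER_CFGS table plus closure-building helpers. The sketch below is illustrative only (it mirrors names from the diff but is not part of the commit); it shows how toggling a benchmark variant becomes a single enabled flag that feeds the derived _enabled list:

    # Illustrative sketch of the table-driven provider pattern (not committed code).
    PROVIDER_CFGS = {
        "torch-bf16": dict(enabled=True),
        "fp8-channel-w-tensor-a": dict(w="channel", a="tensor", no_a_quant=False, enabled=False),
    }

    # Enabling a previously disabled variant is now a one-flag change instead of
    # editing two parallel line_vals / line_names lists:
    PROVIDER_CFGS["fp8-channel-w-tensor-a"]["enabled"] = True

    _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
    print(_enabled)  # these names feed both line_vals and line_names in perf_report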

benchmarks/kernels/bench_fp8_gemm.py

Lines changed: 92 additions & 157 deletions
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import argparse
 import copy
 import itertools
@@ -11,35 +10,89 @@
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
 from vllm.triton_utils import triton

+PROVIDER_CFGS = {
+    "torch-bf16": dict(enabled=True),
+    "fp8-tensor-w-token-a": dict(
+        w="tensor", a="token", no_a_quant=False, enabled=False
+    ),
+    "fp8-tensor-w-tensor-a": dict(
+        w="tensor", a="tensor", no_a_quant=False, enabled=True
+    ),
+    "fp8-channel-w-token-a": dict(
+        w="channel", a="token", no_a_quant=False, enabled=True
+    ),
+    "fp8-channel-w-tensor-a": dict(
+        w="channel", a="tensor", no_a_quant=False, enabled=False
+    ),
+    "fp8-tensor-w-token-a-noquant": dict(
+        w="tensor", a="token", no_a_quant=True, enabled=False
+    ),
+    "fp8-tensor-w-tensor-a-noquant": dict(
+        w="tensor", a="tensor", no_a_quant=True, enabled=True
+    ),
+    "fp8-channel-w-token-a-noquant": dict(
+        w="channel", a="token", no_a_quant=True, enabled=True
+    ),
+    "fp8-channel-w-tensor-a-noquant": dict(
+        w="channel", a="tensor", no_a_quant=True, enabled=False
+    ),
+}
+
+_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
+
+
+def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
+    if w_type == "tensor":
+        scale_b = torch.ones(1, device=device, dtype=torch.float32)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+    else:
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
+    return b_fp8.t(), scale_b_fp8
+
+
+def build_fp8_runner(cfg, a, b, dtype, device):
+    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)
+
+    scale_a_const = (
+        torch.ones(1, device=device, dtype=torch.float32)
+        if cfg["a"] == "tensor"
+        else None
+    )
+
+    if cfg["no_a_quant"]:
+        if cfg["a"] == "tensor":
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+        else:
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+
+        def run():
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+        return run
+
+    if cfg["a"] == "tensor":
+
+        def run():
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+    else:
+
+        def run():
+            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
+
+    return run
+

 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
         x_log=False,
         line_arg="provider",
-        line_vals=[
-            "torch-bf16",
-            # "fp8-tensor-w-token-a",
-            "fp8-tensor-w-tensor-a",
-            "fp8-channel-w-token-a",
-            # "fp8-channel-w-tensor-a",
-            # "fp8-tensor-w-token-a-noquant",
-            "fp8-tensor-w-tensor-a-noquant",
-            "fp8-channel-w-token-a-noquant",
-            # "fp8-channel-w-tensor-a-noquant",
-        ],
-        line_names=[
-            "torch-bf16",
-            # "fp8-tensor-w-token-a",
-            "fp8-tensor-w-tensor-a",
-            "fp8-channel-w-token-a",
-            # "fp8-channel-w-tensor-a",
-            # "fp8-tensor-w-token-a-noquant",
-            "fp8-tensor-w-tensor-a-noquant",
-            "fp8-channel-w-token-a-noquant",
-            # "fp8-channel-w-tensor-a-noquant",
-        ],
+        line_vals=_enabled,
+        line_names=_enabled,
         ylabel="TFLOP/s (larger is better)",
         plot_name="BF16 vs FP8 GEMMs",
         args={},
@@ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K):
     device = "cuda"
     dtype = torch.bfloat16

-    # Create input tensors
     a = torch.randn((M, K), device=device, dtype=dtype)
     b = torch.randn((N, K), device=device, dtype=dtype)

     quantiles = [0.5, 0.2, 0.8]

-    if "torch-bf16" in provider:
+    if provider == "torch-bf16":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
         )
-
-    elif "fp8" in provider:
-        # Weights are always quantized ahead of time
-        if "noquant" in provider:
-            # For no quantization, we just measure the GEMM
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, per-tensor quant for B
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                assert scale_b_fp8.numel() == 1
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                    a, use_per_token_if_dynamic=True
-                )
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales
-                # for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                assert scale_b_fp8.numel() == 1
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-token-a" in provider:
-                # Static per-channel quantization for weights, per-token
-                # quant for A
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                    a, use_per_token_if_dynamic=True
-                )
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-tensor-a" in provider:
-                # Static per-channel quantization for weights, per-tensor
-                # quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-                a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-
-                def run_quant():
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-        else:
-            # In these cases, we quantize the activations during the GEMM call
-            if "tensor-w-token-a" in provider:
-                # Dynamic per-token quant for A, per-tensor quant for B
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b)
-                assert scale_b_fp8.numel() == 1
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                        a, use_per_token_if_dynamic=True
-                    )
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "tensor-w-tensor-a" in provider:
-                # Static per-tensor quantization with fixed scales
-                # for both A and B
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor([1.0], device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                assert scale_b_fp8.numel() == 1
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-token-a" in provider:
-                # Static per-channel quantization for weights, per-token
-                # quant for A
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(
-                        a, use_per_token_if_dynamic=True
-                    )
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-            elif "channel-w-tensor-a" in provider:
-                # Static per-channel quantization for weights, per-tensor
-                # quant for A
-                scale_a = torch.tensor([1.0], device=device, dtype=torch.float32)
-                scale_b = torch.tensor((N,), device=device, dtype=torch.float32)
-                b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-                scale_b_fp8 = scale_b_fp8.expand(N).contiguous()
-                assert scale_b_fp8.numel() == N
-
-                def run_quant():
-                    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-                    return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
-
-        b_fp8 = b_fp8.t()
-
+    else:
+        cfg = PROVIDER_CFGS[provider]
+        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: run_quant(), quantiles=quantiles
         )

-    # Calculate TFLOP/s, two flops per multiply-add
-    tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3)
-    return tflops(ms), tflops(max_ms), tflops(min_ms)
+    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
+    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


 def prepare_shapes(args):
-    KN_model_names = []
-    models_tps = list(itertools.product(args.models, args.tp_sizes))
-    for model, tp_size in models_tps:
-        assert model in WEIGHT_SHAPES
-        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
-            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+    out = []
+    for model, tp_size in itertools.product(args.models, args.tp_sizes):
+        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
+            KN[tp_dim] //= tp_size
             KN.append(model)
-            KN_model_names.append(KN)
-    return KN_model_names
+            out.append(KN)
+    return out


 if __name__ == "__main__":
@@ -197,21 +140,13 @@ def prepare_shapes(args):
         nargs="+",
         type=str,
         default=["meta-llama/Llama-3.1-8B-Instruct"],
-        choices=[*WEIGHT_SHAPES.keys()],
-        help="List of models to benchmark",
-    )
-    parser.add_argument(
-        "--tp-sizes",
-        nargs="+",
-        type=int,
-        default=[1],
-        help="List of tensor parallel sizes",
+        choices=list(WEIGHT_SHAPES.keys()),
     )
+    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
     args = parser.parse_args()

-    KN_model_names = prepare_shapes(args)
-    for K, N, model_name in KN_model_names:
-        print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
+    for K, N, model in prepare_shapes(args):
+        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
         benchmark.run(
             print_data=True,
             show_plots=True,
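As a side note on the reported metric: the to_tflops lambda in benchmark() counts 2 * M * N * K floating-point operations per GEMM (one multiply and one add per accumulated product) and converts the measured latency from milliseconds to seconds. A quick worked example with a hypothetical shape and timing (not taken from the commit):

    # Worked example of the TFLOP/s conversion used in benchmark().
    M, N, K = 1024, 6144, 4096   # hypothetical GEMM shape, for illustration only
    t_ms = 0.5                   # hypothetical measured latency in milliseconds

    flops = 2 * M * N * K        # 51_539_607_552 floating-point operations
    tflops = flops * 1e-12 / (t_ms * 1e-3)
    print(f"{tflops:.1f} TFLOP/s")   # ~103.1 TFLOP/s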
