
Commit 1f5d105

add inference only roofline
1 parent 596da93 commit 1f5d105

File tree

4 files changed: +76 / -21 lines

  .gitignore
  benchmarks/float8/float8_inference_roofline.py
  benchmarks/float8/float8_roofline.py
  torchao/testing/training/roofline_utils.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ aten/build/
 aten/src/ATen/Config.h
 aten/src/ATen/cuda/CUDAConfig.h
 benchmarks/.data
+benchmarks/data
 caffe2/cpp_test/
 dist/
 docs/build/

benchmarks/float8/float8_inference_roofline.py

Lines changed: 27 additions & 9 deletions
@@ -122,37 +122,53 @@ def run(
     shape_gen_name: str = "pow2",
     n_limit: Optional[int] = None,
     float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `float8_recipe_name (optional)`: float8 quantization recipe
+    * `mx_recipe_name (optional)`: MX format recipe
+    * `nvfp4_recipe_name (optional)`: NVFP4 format recipe
     """
 
-    assert float8_recipe_name is not None, "unsupported"
+    recipe_count = sum(
+        x is not None for x in [float8_recipe_name, mx_recipe_name, nvfp4_recipe_name]
+    )
+
+    assert recipe_count <= 1, "Only one recipe type can be specified at a time"
+
+    if recipe_count == 0:
+        float8_recipe_name = "tensorwise"
 
     print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"torch version: {torch.__version__}")
     print(f"torchao version: {torchao.__version__}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
     print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"nvfp4_recipe_name: {nvfp4_recipe_name}")
 
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
-        M,
-        K,
-        N,
-        float8_recipe_name,
+        M, K, N, float8_recipe_name, mx_recipe_name, nvfp4_recipe_name
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(
         M, K, N, torch.bfloat16, None, None
     )
-    fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
-    )
+    if nvfp4_recipe_name is not None:
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float4_e2m1fn_x2, float8_recipe_name, nvfp4_recipe_name
+        )
+    else:
+        fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
+            M, K, N, torch.float8_e4m3fn, float8_recipe_name, None
+        )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)

@@ -261,6 +277,8 @@ def run(
             m_fp8_dyn = torch.compile(m_fp8_dyn)
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
+        roofline_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+
         results.append(
             [
                 M_val,

@@ -273,7 +291,7 @@ def run(
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
-                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                roofline_speedup,
                 # benchmarks - gemm
                 b_bf16_gemm_time_s,
                 b_fp8_gemm_time_s,
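
For intuition, the roofline speedup this script reports is the bf16 gemm time divided by the sum of the quantized gemm time and the quantization overhead. The standalone sketch below (not part of the commit) reproduces that arithmetic with sympy for a rowwise-style overhead model; the peak-throughput and bandwidth constants are placeholder assumptions for illustration, not the per-GPU specs torchao looks up.

    import sympy

    M, K, N = sympy.symbols("M K N")

    # assumed peak numbers, for illustration only
    BF16_PEAK_TOPS = 1000e12   # hypothetical bf16 tensor-core peak, ops/sec
    FP8_PEAK_TOPS = 2000e12    # hypothetical fp8 tensor-core peak, ops/sec
    PEAK_MEM_BW = 3.35e12      # hypothetical HBM bandwidth, bytes/sec

    # compute-bound gemm time: 2*M*K*N multiply-accumulate ops
    bf16_gemm_time_s = 2 * M * K * N / BF16_PEAK_TOPS
    fp8_gemm_time_s = 2 * M * K * N / FP8_PEAK_TOPS

    # quantization overhead: read the bf16 activation, write it back as fp8
    # (per the rowwise branch in roofline_utils.py, ignoring the small scale term)
    fp8_ovhd_time_s = (2 * M * K + 1 * M * K) / PEAK_MEM_BW

    roofline_speedup = bf16_gemm_time_s / (fp8_gemm_time_s + fp8_ovhd_time_s)

    # evaluate at a concrete shape, mirroring how the script substitutes M, K, N
    print(roofline_speedup.subs({M: 4096, K: 4096, N: 4096}))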

benchmarks/float8/float8_roofline.py

Lines changed: 2 additions & 0 deletions
@@ -214,6 +214,8 @@ def run(
     * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `float8_recipe_name (optional)`: float8 quantization recipe
+    * `mx_recipe_name (optional)`: MX format recipe
     * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """
 
torchao/testing/training/roofline_utils.py

Lines changed: 46 additions & 12 deletions
@@ -12,6 +12,8 @@
 BYTES_PER_EL_FLOAT4 = 0.5
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2
+BYTES_PER_EL_FLOAT8_E8M0 = 1
+BYTES_PER_EL_FLOAT32 = 4
 
 gpu_name_to_specs = {
     "NVIDIA H100": {

@@ -241,7 +243,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         peak_tops = specs["fp4_peak_tops"]
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
 
     # memory bound

@@ -274,7 +276,7 @@ def get_individual_gemm_time_sympy(
     elif dtype is torch.float4_e2m1fn_x2:
         bytes_rw = num_reads * BYTES_PER_EL_FLOAT4 + num_writes * BYTES_PER_EL_BF16
     else:
-        assert False, "unsupported"
+        assert False, f"unsupported dtype: {dtype}"
     mem_gemm_time_s = (
         bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )

@@ -376,27 +378,56 @@ def get_inference_tensor_memory_traffic_ovhd_s(
     dim1,
     tensor_role: str,
     float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     fuse_with_prev=False,
 ) -> List[Union[sympy.Symbol, float]]:
     """
     Inference version of `get_tensor_memory_traffic_ovhd_s`.
     The only thing happening here is we quantize the activation.
     """
-    assert float8_recipe_name == "rowwise", "unsupported"
     assert fuse_with_prev is False, "unsupported"
+    assert tensor_role == "input", "inference only quantizes input activations"
 
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
     res_bytes = None
 
-    assert tensor_role == "input"
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> x_fp8
-    kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
-    res_bytes = [
-        kernel_1_rw,
-    ]
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        # kernel 3: read in bf16, write in float8
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw, kernel_3_rw]
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_fp8 (with per-row scaling)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT32 * dim0
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil"):
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8 (block-wise scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        # add in the bytes for scale writes
+        kernel_1_rw += BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
+        res_bytes = [kernel_1_rw]
+
+    else:
+        # For NVFP4, assume minimal overhead since it's primarily a compute format
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_nvfp4 (per-tensor scaling for inference)
+        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel
+        # add minimal scaling overhead (per-tensor scale)
+        kernel_1_rw += BYTES_PER_EL_FLOAT32  # single scale factor
+        res_bytes = [kernel_1_rw]
 
     # convert from bytes to seconds
     res_s = [

@@ -415,6 +446,8 @@ def get_inference_float8_mem_sympy(
     K,
     N,
     float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str] = None,
+    nvfp4_recipe_name: Optional[str] = None,
     gpu_name: Optional[str] = None,
 ):
     specs = get_specs(gpu_name)

@@ -426,6 +459,7 @@ def get_inference_float8_mem_sympy(
         K,
         tensor_role="input",
         float8_recipe_name=float8_recipe_name,
+        mx_recipe_name=mx_recipe_name,
        fuse_with_prev=False,
     )
     res = sum([*fwd_fp8_input_mem])

@@ -438,9 +472,9 @@ def get_inference_gemm_time_sympy(
     N: sympy.Symbol,
     dtype,
     float8_recipe_name: Optional[str],
-    gpu_name: Optional[str],
+    nvfp4_recipe_name: Optional[str] = None,
+    gpu_name: Optional[str] = None,
 ):
-    assert float8_recipe_name == "rowwise" or float8_recipe_name is None, "unsupported"
     # note: this function is currently not super accurate for small shapes:
     # when M,K,N <= 1k,1k,1k it undercounts by around 2x
     gemm_output_time_s = get_individual_gemm_time_sympy(M, K, N, dtype, None, gpu_name)
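
As a sanity check on the byte-traffic model added to get_inference_tensor_memory_traffic_ovhd_s, the standalone sketch below (not part of the commit) evaluates each branch for a single bf16 activation of shape 4096 x 4096, reusing the byte-per-element constants from above; the shape is an arbitrary example.

    BYTES_PER_EL_BF16 = 2
    BYTES_PER_EL_FLOAT8 = 1
    BYTES_PER_EL_FLOAT4 = 0.5
    BYTES_PER_EL_FLOAT8_E8M0 = 1
    BYTES_PER_EL_FLOAT32 = 4

    dim0, dim1 = 4096, 4096
    numel = dim0 * dim1

    # rowwise float8: read bf16, write fp8, plus one fp32 scale per row
    rowwise = (
        BYTES_PER_EL_BF16 * numel
        + BYTES_PER_EL_FLOAT8 * numel
        + BYTES_PER_EL_FLOAT32 * dim0
    )

    # mxfp8: read bf16, write fp8, plus one e8m0 scale per 32-element block
    mxfp8 = (
        BYTES_PER_EL_BF16 * numel
        + BYTES_PER_EL_FLOAT8 * numel
        + BYTES_PER_EL_FLOAT8_E8M0 * dim0 * (dim1 // 32)
    )

    # nvfp4, as modeled here: read bf16, write fp4, plus a single per-tensor fp32 scale
    nvfp4 = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT4 * numel + BYTES_PER_EL_FLOAT32

    for name, nbytes in [("rowwise fp8", rowwise), ("mxfp8", mxfp8), ("nvfp4", nvfp4)]:
        print(f"{name}: {nbytes / 1e6:.1f} MB of quantization traffic")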
