roofline estimator: simplify (#1783)
* Update

[ghstack-poisoned]
vkuzo authored Feb 27, 2025
1 parent cd69415 commit f478692
Showing 2 changed files with 83 additions and 96 deletions.
164 changes: 82 additions & 82 deletions benchmarks/float8/float8_roofline.py
@@ -6,13 +6,14 @@

"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
layer to float8, by estimating the difference in e2e GPU kernel time between:
layer to float8 given a single saturated GPU, by estimating the difference
in e2e GPU kernel time between:
1. bf16 gemms in fwd and bwd, and
2. float8 gemms in fwd and bwd, and float8 overhead
The gemm times are estimated either from direct measurements via benchmarks,
or with a roofline estimation based on TOPS and peak memory bandwidth of an
NVIDIA H100.
NVIDIA H100 or B200.
The float8 overhead times are estimated by counting memory reads and writes
based on the specified float8 scaling, and estimating that we can achieve
@@ -31,12 +32,10 @@
input_t @ grad_output = grad_weight
KxM @ MxN => KxN
2. we properly model the worst-case of the current torch.compile limitations regarding
float8 scaling
3. assume for float8 activations/gradients that torch.compile will fuse to the
2. assume for float8 activations/gradients that torch.compile will fuse to the
preceding op. Note that this is not always true in practice.
4. assume no AC (TODO model it)
5. assume no float8 all-gather (TODO model it)
3. assume no AC (TODO model it)
4. assume no float8 all-gather (TODO model it)
"""

import copy
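The roofline gemm estimate referenced in the docstring above boils down to taking the larger of a compute-bound and a memory-bound time. The sketch below is illustrative only and is not part of this commit; the peak-throughput and bandwidth constants are assumed H100-class values, while the actual script derives its numbers from get_specs() and builds sympy expressions in get_gemm_time_sympy.

# --- illustrative sketch, not part of this commit ---
ASSUMED_PEAK_BF16_FLOPS = 989e12   # assumed H100 dense bf16 tensor-core throughput, FLOP/s
ASSUMED_PEAK_FP8_FLOPS = 1979e12   # assumed H100 dense fp8 tensor-core throughput, FLOP/s
ASSUMED_PEAK_MEM_BW = 3.35e12      # assumed H100 HBM bandwidth, bytes/s

def roofline_gemm_time_s(M, K, N, bytes_per_el, peak_flops_per_s):
    flops = 2 * M * K * N                               # one multiply-add per (m, n, k) triple
    mem_bytes = bytes_per_el * (M * K + K * N + M * N)  # read A and B, write C (all at bytes_per_el; a simplification)
    compute_s = flops / peak_flops_per_s
    mem_s = mem_bytes / ASSUMED_PEAK_MEM_BW
    return max(compute_s, mem_s)                        # bound by whichever roof is hit first

# e.g. the same 16384 x 8192 x 8192 gemm in bf16 vs fp8:
bf16_s = roofline_gemm_time_s(16384, 8192, 8192, 2, ASSUMED_PEAK_BF16_FLOPS)
fp8_s = roofline_gemm_time_s(16384, 8192, 8192, 1, ASSUMED_PEAK_FP8_FLOPS)
# --- end sketch ---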
@@ -164,68 +163,60 @@ def do_matmul(A, B):

def run(
outfile: str,
gemm_time_strategy: str = "benchmarks",
model_torch_compile_limitations: bool = False,
do_benchmarks: bool = True,
shape_gen_name: str = "square",
gemm_cache_filename: Optional[str] = None,
n_limit: Optional[int] = None,
):
"""
Args:
* `gemm_time_strategy`:
- `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
- `roofline`: use roofline model for gemm times (only accurate for large shapes)
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
* `shape_gen_name`: `llama`, `square`, or `sweep`
* `gemm_cache_filename (optional)`: file to cache gemm benchmark results
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
"""

print(f"gemm_time_strategy: {gemm_time_strategy}")
print(f"do_benchmarks: {do_benchmarks}")
print(f"shape_gen_name: {shape_gen_name}")

assert gemm_time_strategy in (
"benchmarks",
"roofline",
), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"

M, K, N = sympy.symbols("M K N")

fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=True,
)
fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations=False,
)

if gemm_time_strategy == "roofline":
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print()
else:
print()
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print()

headers = [
"fwd_M",
"fwd_K",
"fwd_N",
# gemm microbenchmarks
"bf16_gemm_s",
"fp8_gemm_s",
# roofline memory overhead estimates
"fp8_oh_estimated",
"fp8_oh_ideal",
# actual e2e measurements
"bf16_s",
"fp8_dyn_s",
"fp8_dyn_sp",
# roofline - gemm time (fwd + bwd, 3 gemms)
"r_bf16_gemm_s",
"r_fp8_gemm_s",
# roofline - fp8 overhead time (by counting reads/writes in the ideal case)
"r_fp8_ovhd_s",
# roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
"r_fp8_gemm_and_ovhd_s",
"r_fp8_gemm_and_ovhd_spdp",
# benchmarks - gemm time (fwd + bwd, 3 gemms)
"b_bf16_gemm_s",
"b_fp8_gemm_s",
# benchmarks - e2e LNLinearSigmoid time fwd + bwd
"b_bf16_e2e_s",
"b_fp8_e2e_s",
# note that e2e speedup is not the same as the roofline speedup:
# 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time)
# 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid)
# the difference is the fwd+bwd ln and sigmoid terms; for now, to keep things simple,
# we don't break them out and don't have a roofline for them.
"b_fp8_e2e_spdp",
]
results = []
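To make the roofline-vs-e2e speedup note in the headers list above concrete, here is an illustrative calculation; all of the times below are made up and are not measurements from this PR.

# --- illustrative sketch, not part of this commit; all times are made up ---
r_bf16_gemm_s, r_fp8_gemm_s, r_fp8_ovhd_s = 1.00e-3, 0.55e-3, 0.15e-3
roofline_speedup = r_bf16_gemm_s / (r_fp8_gemm_s + r_fp8_ovhd_s)  # ~1.43x, gemms only

ln_s, sigmoid_s = 0.05e-3, 0.02e-3  # assumed fwd+bwd time of the LN and sigmoid ops
e2e_speedup = (ln_s + r_bf16_gemm_s + sigmoid_s) / (
    ln_s + r_fp8_gemm_s + r_fp8_ovhd_s + sigmoid_s
)  # ~1.39x: the unchanged LN/sigmoid terms dilute the gemm speedup
# --- end sketch ---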

@@ -235,7 +226,18 @@ def run(
if n_limit is not None and idx >= n_limit:
break

if gemm_time_strategy == "benchmarks":
# use roofline model to estimate gemm time
# note: cast from sympy.core.numbers.Float to float to make pandas formatting work
r_bf16_gemm_time_s = float(
bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
r_fp8_gemm_time_s = float(
fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

# if enabled, also measure the observed gemm times
b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
if do_benchmarks:
bf16_g1, f8_g1 = get_gemm_times(
M_val, K_val, N_val, True, gemm_cache_filename
)
@@ -245,60 +247,58 @@
bf16_g3, f8_g3 = get_gemm_times(
K_val, M_val, N_val, False, gemm_cache_filename
)
bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
else:
assert gemm_time_strategy == "roofline", "unsupported"
bf16_time_val = (
bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_gemm_time_s = (
fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3

fp8_mem_time_dyn_limit_s = (
fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)
fp8_mem_time_dyn_nolimit_s = (
# note: cast from sympy.core.numbers.Float to float to make pandas formatting work
r_fp8_ovhd_time_s = float(
fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
x = torch.randn(
M_val, K_val, dtype=torch.bfloat16, device="cuda"
).requires_grad_()
b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
if do_benchmarks:
# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
x = torch.randn(
M_val, K_val, dtype=torch.bfloat16, device="cuda"
).requires_grad_()

# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)

# get the float8 dynamic scaling gpu kernel time
# get the float8 dynamic scaling gpu kernel time

torch._dynamo.reset()
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)
fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
torch._dynamo.reset()
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
m_fp8_dyn = torch.compile(m_fp8_dyn)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)

results.append(
[
M_val,
K_val,
N_val,
# gemm microbenchmarks
bf16_time_val,
fp8_gemm_time_s,
# roofline overhead estimates
fp8_mem_time_dyn_limit_s,
fp8_mem_time_dyn_nolimit_s,
# e2e numbers
bf16_time_actual_s,
fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
# roofline - gemm
r_bf16_gemm_time_s,
r_fp8_gemm_time_s,
# roofline - fp8 overhead
r_fp8_ovhd_time_s,
# roofline - gemm + overhead, and speedup
r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
# benchmarks - gemm
b_bf16_gemm_time_s,
b_fp8_gemm_time_s,
# benchmarks - e2e, and speedup
b_bf16_e2e_time_s,
b_fp8_e2e_time_s,
b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
]
)

pd.set_option("display.precision", 2)
df = pd.DataFrame(results, columns=headers)
print(df)
df.to_csv(outfile)
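For reference, a hypothetical direct call of run() with the simplified signature might look as follows; the script's actual CLI wiring is not shown in this diff, and the argument values are arbitrary.

# --- illustrative usage, not part of this commit ---
run(
    outfile="float8_roofline_results.csv",
    do_benchmarks=True,       # also benchmark gemms and e2e fwd+bwd of LNLinearSigmoid
    shape_gen_name="llama",   # one of "llama", "square", "sweep"
    gemm_cache_filename=None,
    n_limit=4,                # only process the first 4 shapes
)
# --- end sketch ---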
15 changes: 1 addition & 14 deletions torchao/testing/float8/roofline_utils.py
@@ -56,7 +56,6 @@ def get_tensor_memory_traffic_bytes(
dim0,
dim1,
fuse_with_prev=False,
model_torch_compile_limitations=False,
):
# assumes input bf16, output f8
numel = dim0 * dim1
@@ -75,15 +74,7 @@
# kernel 3: read in bf16, write twice in float8 (row-major and col-major)
kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

if model_torch_compile_limitations:
# today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
# has an extra memory read of the input in fp8
# context: https://github.com/pytorch/pytorch/issues/130015
tc_adjustment = numel * BYTES_PER_EL_FLOAT8
else:
tc_adjustment = 0

return kernel_1_rw + kernel_3_rw + tc_adjustment
return kernel_1_rw + kernel_3_rw


def get_gemm_time_sympy(M, K, N, dtype):
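To make the byte counting in get_tensor_memory_traffic_bytes concrete, here is an illustrative calculation for a single 4096 x 4096 bf16 tensor cast with dynamic scaling. This is not code from this commit; the kernel-1 term is hidden by the diff truncation above and is assumed here to be one bf16 read of the input when the cast is not fused with the preceding op.

# --- illustrative sketch, not part of this commit ---
BYTES_PER_EL_BF16, BYTES_PER_EL_FLOAT8 = 2, 1
numel = 4096 * 4096
# kernel 1 (assumed): read the bf16 input once to compute the scale; 0 bytes if fused with the prev op
kernel_1_rw = BYTES_PER_EL_BF16 * numel                                    # 32 MiB
# kernel 3: read the bf16 input, write float8 twice (row-major and col-major)
kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel  # 64 MiB
total_traffic_bytes = kernel_1_rw + kernel_3_rw                            # 96 MiB for this one tensor
# --- end sketch ---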
@@ -101,7 +92,6 @@ def get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations: bool = False,
):
specs = get_specs()

@@ -123,13 +113,11 @@
M,
K,
fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations,
)
fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
K,
N,
fuse_with_prev=False,
model_torch_compile_limitations=model_torch_compile_limitations,
)
fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem

@@ -140,7 +128,6 @@
M,
N,
fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations,
)
# already casted, assuming that we save weight from fw to bw
# TODO: model this if FSDP float8 all-gather is on
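As described in the script docstring, byte counts like the ones above become time estimates by dividing by an achievable fraction of peak memory bandwidth. A minimal sketch follows; both constants are assumed rather than taken from this PR, and the real code builds a sympy expression over M, K, N using specs from get_specs().

# --- illustrative sketch, not part of this commit ---
ASSUMED_PEAK_MEM_BW = 3.35e12      # assumed H100-class HBM bandwidth, bytes/s
ASSUMED_ACHIEVABLE_FRACTION = 0.6  # assumed achievable fraction of peak for the cast kernels

def fp8_overhead_time_s(total_cast_bytes: float) -> float:
    return total_cast_bytes / (ASSUMED_PEAK_MEM_BW * ASSUMED_ACHIEVABLE_FRACTION)
# --- end sketch ---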
