float8 gemm benchmarks: add option for gpu time (#666)
Summary:

Adds an option to run the bf16 gemm vs float8 gemm benchmark using
GPU kernel time instead of end-to-end (e2e) time. This is useful for
upcoming work on float8 roofline estimation.

Test Plan:

```
python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
// results: https://gist.github.com/vkuzo/2f09182b795fa4156939ad707966d6c3
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Aug 13, 2024
1 parent 88a263a commit 174e630
Showing 2 changed files with 41 additions and 6 deletions.
benchmarks/float8/bench_matmul.py (45 changes: 40 additions & 5 deletions)
```diff
@@ -12,8 +12,12 @@
 import torch
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
+from torch.profiler import profile, ProfilerActivity, record_function
 
-from utils import get_name_to_shapes_iter
+from utils import (
+    get_name_to_shapes_iter,
+    profiler_output_to_filtered_time_by_kernel_name,
+)
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -43,8 +47,38 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean
 
 
-def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
-    time_sec = benchmark_fn_in_sec(f, *args, **kwargs)
+def get_gpu_kernel_gemm_time(f, *args, **kwargs):
+    # warmup
+    f(*args, **kwargs)
+    n_iter = 5
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        for idx in range(n_iter):
+            f(*args, **kwargs)
+    data = profiler_output_to_filtered_time_by_kernel_name(prof, n_iter, num_leaf_tensors=0)
+    # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
+    assert len(data) == 1
+    if "aten::mm" in data:
+        return data["aten::mm"] / 1e6 / n_iter
+    elif "aten::_scaled_mm" in data:
+        return data["aten::_scaled_mm"] / 1e6 / n_iter
+    else:
+        raise AssertionError("unexpected format of data")
+
+
+def do_benchmarks(
+    tops,
+    peak_tops,
+    use_gpu_kernel_time,
+    f,
+    *args,
+    **kwargs,
+):
+    if use_gpu_kernel_time:
+        # just the gemm GPU kernel
+        time_sec = get_gpu_kernel_gemm_time(f, *args, **kwargs)
+    else:
+        # e2e time including kernel launch overhead
+        time_sec = benchmark_fn_in_sec(f, *args, **kwargs)
     tops_sec = float(tops) / time_sec
     pct_top_peak = tops_sec / peak_tops
     return time_sec, tops_sec, pct_top_peak
```
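The `profiler_output_to_filtered_time_by_kernel_name` helper comes from `benchmarks/float8/utils.py` and is not part of this diff. As a rough, hypothetical sketch of the kind of aggregation it performs, assuming it sums device time per top-level op from the profiler's averaged events (the real helper's filtering, its `num_leaf_tensors` handling, and its time units may differ):

```python
from collections import defaultdict

import torch
from torch.profiler import ProfilerActivity, profile


def gpu_time_by_op_name(prof):
    """Hypothetical stand-in for profiler_output_to_filtered_time_by_kernel_name:
    sum total device (CUDA) time, in microseconds, per gemm op name."""
    gemm_ops = {"aten::mm", "aten::_scaled_mm"}
    times = defaultdict(float)
    for evt in prof.key_averages():
        if evt.key in gemm_ops:
            # cuda_time_total is the aggregated device time in microseconds
            times[evt.key] += evt.cuda_time_total
    return dict(times)


if __name__ == "__main__":
    A = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(5):
            torch.mm(A, B)
    print(gpu_time_by_op_name(prof))  # e.g. {'aten::mm': <total device time in us>}
```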
```diff
@@ -58,6 +92,7 @@ def run(
     M: Optional[int] = None,
     K: Optional[int] = None,
     N: Optional[int] = None,
+    use_gpu_kernel_time: bool = False,
 ):
     device = "cuda"
 
@@ -79,7 +114,7 @@
         A = torch.randn(M, K, device=device, dtype=dtype)
         m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False))
         ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
-            tops, dtype_to_peak_tops[dtype], m_ref, A
+            tops, dtype_to_peak_tops[dtype], use_gpu_kernel_time, m_ref, A
         )
         print(
             f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
@@ -101,7 +136,7 @@ def do_matmul(A, B):
             )
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
-            tops, dtype_to_peak_tops[d1], do_matmul, A, B
+            tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
         )
         print(
             f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
```
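For context, here is a self-contained sketch of the two timing modes that `use_gpu_kernel_time` toggles between, built from the same pieces used above (`torch.utils.benchmark` for e2e time, `torch.profiler` for device-only time). The shape, iteration count, and variable names are illustrative and not taken from the benchmark script:

```python
import torch
import torch.utils.benchmark as benchmark
from torch.profiler import ProfilerActivity, profile

M = K = N = 4096
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# Mode 1: e2e time, which includes kernel launch and other host-side overhead.
timer = benchmark.Timer(stmt="torch.mm(A, B)", globals={"torch": torch, "A": A, "B": B})
e2e_sec = timer.blocked_autorange().mean

# Mode 2: GPU kernel time only, taken from the profiler's aten::mm entry.
n_iter = 5
torch.mm(A, B)  # warmup
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(n_iter):
        torch.mm(A, B)
mm_evt = next(e for e in prof.key_averages() if e.key == "aten::mm")
kernel_sec = mm_evt.cuda_time_total / 1e6 / n_iter  # microseconds -> seconds, per iteration

print(f"e2e: {e2e_sec:.2e} s  gpu kernel: {kernel_sec:.2e} s")
```

The kernel-only number excludes host-side launch overhead, which is why it is the more useful quantity for the roofline comparisons mentioned in the summary.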
benchmarks/float8/profile_linear_float8.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -360,7 +360,7 @@ def float8_forw_backward_wrapper(x):
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
-    if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "":
+    if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
         context = nullcontext()
         f = None
     else:
```
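The one-character change above appears to make the condition match the comment directly above it: profiling output is captured only when `TORCHINDUCTOR_PROFILE` is set, and a `nullcontext` is used otherwise. A minimal sketch of that env-var-gated pattern (the file handling and names here are illustrative, not the script's actual code):

```python
import os
import tempfile
from contextlib import nullcontext, redirect_stdout

if os.environ.get("TORCHINDUCTOR_PROFILE", "") == "":
    # env var unset: don't capture anything
    context = nullcontext()
    f = None
else:
    # env var set: redirect the profiling printout to a file for later parsing
    f = tempfile.NamedTemporaryFile(mode="w+", delete=False)
    context = redirect_stdout(f)

with context:
    pass  # run the profiled workload here

if f is not None:
    f.close()
```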
