Merge branch 'pytorch:main' into compile_guard
gau-nernst committed Aug 8, 2024
2 parents 759895c + e11201a commit 2f0aef4
Showing 4 changed files with 148 additions and 85 deletions.
benchmarks/float8/profile_linear_float8.py (79 additions, 75 deletions)
@@ -29,7 +29,7 @@
     kernel_name_to_category,
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
-    profiler_output_to_time_by_kernel_name,
+    profiler_output_to_filtered_time_by_kernel_name,
 )
 
 # don't truncate long kernel names
@@ -312,85 +312,89 @@ def float8_forw_backward_wrapper(x):
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
     f = io.StringIO()
-    with redirect_stdout(f):
-        # warm up
-        for _ in range(1):
-            if dtype_filter != "float8":
-                ref_forw_backward(input_tensor)
-            if dtype_filter != "bfloat16":
-                float8_forw_backward_wrapper(input_tensor)
-
-        profile_iters = 5
-        ref_times, float8_times = None, None
-        data = []
-
-        if dtype_filter != "float8":
-            # Profile Reference Model
-            print("profiling ref")
-            ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-            ref_path = profile_path_prefix + ref_suffix
-            profile_config = ProfileConfig(
-                ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
-            )
-            p = profile_function(profile_config, ref_forw_backward, input_tensor)
-            print(f"saved {ref_path}")
-            ref_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
-            for k, v in ref_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "0_ref",
-                        k,
-                        kernel_name_to_category(k),
-                        v_ms,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
-
-        if dtype_filter != "bfloat16":
-            # Profile Float8 Model
-            print("profiling float8")
-            float8_suffix = (
-                f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-            )
-            float8_path = profile_path_prefix + float8_suffix
-            profile_config = ProfileConfig(
-                float8_path,
-                float8_suffix,
-                iters=profile_iters,
-                warmup_iters=2,
-                sync=True,
-            )
-            p = profile_function(
-                profile_config, float8_forw_backward_wrapper, input_tensor
-            )
-            print(f"saved {float8_path}")
-            float8_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
-            for k, v in float8_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "1_float8",
-                        k,
-                        kernel_name_to_category(k),
-                        v / 1e3 / profile_iters,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
-
-            # get the time spent per user annotation
-            sync_time_us = profiler_output_to_gpu_time_for_key(
-                p, "scale_amax_and_scales"
-            )
-            sync_time_ms = sync_time_us / profile_iters / 1e3
-            print(f"Sync time ms: {sync_time_ms}")
-
-    # print the redirected stdout back to regular stdout
-    print(f.getvalue())
+    try:
+        with redirect_stdout(f):
+            # warm up
+            for _ in range(1):
+                if dtype_filter != "float8":
+                    ref_forw_backward(input_tensor)
+                if dtype_filter != "bfloat16":
+                    float8_forw_backward_wrapper(input_tensor)
+
+            profile_iters = 5
+            ref_times, float8_times = None, None
+            data = []
+
+            num_leaf_tensors = 1 + len(list(m_ref.parameters()))
+
+            if dtype_filter != "float8":
+                # Profile Reference Model
+                print("profiling ref")
+                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_path = profile_path_prefix + ref_suffix
+                profile_config = ProfileConfig(
+                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+                )
+                p = profile_function(profile_config, ref_forw_backward, input_tensor)
+                print(f"saved {ref_path}")
+                ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+                for k, v in ref_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "0_ref",
+                            k,
+                            kernel_name_to_category(k),
+                            v_ms,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+            if dtype_filter != "bfloat16":
+                # Profile Float8 Model
+                print("profiling float8")
+                float8_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+                )
+                float8_path = profile_path_prefix + float8_suffix
+                profile_config = ProfileConfig(
+                    float8_path,
+                    float8_suffix,
+                    iters=profile_iters,
+                    warmup_iters=2,
+                    sync=True,
+                )
+                p = profile_function(
+                    profile_config, float8_forw_backward_wrapper, input_tensor
+                )
+                print(f"saved {float8_path}")
+                float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                for k, v in float8_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "1_float8",
+                            k,
+                            kernel_name_to_category(k),
+                            v / 1e3 / profile_iters,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )
+
+                # get the time spent per user annotation
+                sync_time_us = profiler_output_to_gpu_time_for_key(
+                    p, "scale_amax_and_scales"
+                )
+                sync_time_ms = sync_time_us / profile_iters / 1e3
+                print(f"Sync time ms: {sync_time_ms}")
+
+    finally:
+        # print the redirected stdout back to regular stdout
+        print(f.getvalue())
 
     # populate the triton kernel bandwidth
     for line in f.getvalue().split("\n"):
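
The hunk above wraps the profiling run in try/finally so that anything printed under redirect_stdout (for example TORCHINDUCTOR_PROFILE output) is echoed back even if profiling raises partway through. A minimal standalone sketch of that pattern, with a placeholder profiling_fn rather than this script's real functions:

    import io
    from contextlib import redirect_stdout

    def run_with_captured_stdout(profiling_fn):
        # capture everything profiling_fn prints
        f = io.StringIO()
        try:
            with redirect_stdout(f):
                profiling_fn()
        finally:
            # echo the captured output even if profiling_fn raised
            print(f.getvalue())
        return f.getvalue()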
benchmarks/float8/utils.py (60 additions, 4 deletions)
@@ -9,10 +9,44 @@
 from typing import Optional
 
 
-def profiler_output_to_time_by_kernel_name(prof):
+def profiler_output_to_filtered_time_by_kernel_name(
+    prof,
+    num_iter: int,
+    num_leaf_tensors: int,
+):
     """
-    Input: a profiler with captured events.
-    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name
+    Input:
+    * `prof`: a profiler with captured events
+    * `num_iter`: number of iterations used to capture `prof`
+    * `num_leaf_tensors`: number of leaf tensors to accumulate gradients to
+
+    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name,
+    with the microbenchmark overhead filtered out
+
+    Currently assumes that `prof` captured events from a microbenchmark which was
+    set up as follows:
+
+        #
+        # Forward pass
+        #
+
+        # Expected GPU kernel overhead: none
+        y = func(...)
+
+        # Convenient way to set up the backward pass without caring about shapes
+        y_sum = y.sum()
+
+        # Expected GPU kernel overhead:
+        # * the call to `sum`
+
+        #
+        # Backward pass
+        #
+        y_sum.backward()
+
+        # Expected GPU kernel overhead:
+        # * the call to `aten.fill_` to put a tensor with a single 1.0 value as the input to the backward
+        # * the call to `aten.copy_` to fill the first `grad_output` tensor with 1.0
+        # * the call to `aten.add_` to accumulate grads, once per leaf tensor
+
     Note that if there are user_annotations in the captured events, `torch.profiler`
     will include their time in the total GPU time displayed at the bottom of

@@ -23,13 +57,35 @@ def profiler_output_to_time_by_kernel_name(prof):
     thresh = 1e-10
     kernel_name_to_gpu_time_us = collections.defaultdict(float)
     for e in key_averages:
+
         # manually filter top-level CPU events with attributed CUDA time
-        # example CPU event row:
+        # example CPU event row from printing `key_averages`:
         #   aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1
         # and it maps to this CUDA event:
         #   sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1
         if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh):
             continue
+
+        # manually filter expected microbenchmarking overhead, in order of execution
+        if e.key == 'aten::sum':
+            # forward pass sum
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::fill_':
+            # filling the forward pass sum with 1.0
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::copy_':
+            # copying 1.0 from grad_out of `sum` to grad_out of next op
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::add_':
+            # accumulating gradients into leaf tensors
+            assert e.count == (num_iter * num_leaf_tensors), f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'cudaDeviceSynchronize':
+            continue
+
         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
test/float8/test_inference_flows.py (1 addition, 0 deletions)
@@ -267,6 +267,7 @@ def setup_mock(self):
         not torch.cuda.is_available() or not is_H100,
         "CUDA not available or on non H100 machine",
     )
+    @unittest.skip("Pytorch needs a fix to ensure codegen maintains stride order")
     def test_fp8_export(self):
         export_model = FeedForward().to("cuda")
         quant_config = QuantConfig(ActivationCasting.DYNAMIC)
torchao/quantization/autoquant.py (8 additions, 6 deletions)
@@ -15,14 +15,10 @@
 from .quant_primitives import (
     safe_int_mm,
 )
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
 
 import torch.nn.functional as F
-try:
-    from torch._inductor.utils import do_bench
-except:
-    from torch._inductor.runtime.runtime_utils import do_bench
 
 __all__ = [
     "AutoQuantizableLinearWeight",
@@ -227,10 +223,16 @@ def do_autoquant_bench(op, *args, **kwargs):
         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
             op(*args, **kwargs)
-        if TORCH_VERSION_AFTER_2_3:
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._inductor.runtime.benchmarking import benchmarker
+            res = benchmarker.benchmark_gpu(
+                lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median"
+            )
+        elif TORCH_VERSION_AFTER_2_3:
             from torch._inductor.runtime.runtime_utils import do_bench_gpu
             res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
         else:
+            from torch._inductor.utils import do_bench
             res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
     return res
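
The three helpers above are different homes for the same inductor GPU benchmarking utility across torch releases. A hedged alternative sketch, not what this commit does, is to probe the imports directly instead of gating on version flags:

    def _get_gpu_bencher():
        # prefer the newest inductor benchmarking API, falling back on older module layouts
        try:
            from torch._inductor.runtime.benchmarking import benchmarker  # torch >= 2.5
            return lambda fn, warmup, rep: benchmarker.benchmark_gpu(
                fn, warmup=warmup, rep=rep, return_mode="median"
            )
        except ImportError:
            pass
        try:
            from torch._inductor.runtime.runtime_utils import do_bench_gpu  # torch 2.3 / 2.4
            return lambda fn, warmup, rep: do_bench_gpu(
                fn, warmup=warmup, rep=rep, return_mode="median"
            )
        except ImportError:
            from torch._inductor.utils import do_bench  # older torch
            return lambda fn, warmup, rep: do_bench(
                fn, warmup=warmup, rep=rep, return_mode="median"
            )

The commit keeps the explicit TORCH_VERSION_AFTER_* gates instead, which avoids masking unrelated ImportErrors from these private modules.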

