From 34b24f7d10f50d26c3524e6a21773281feaa5b52 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 8 Aug 2024 06:55:39 -0700 Subject: [PATCH 1/4] float8 profiling script: filter out microbenchmarking overhead (#629) Summary: Our microbenchmarks have a lot of overhead. This PR attempts to get a cleaner measurement of only the kernels in the fwd+bwd, and subtracts the kernels unrelated to fwd+bwd code. This makes the kernel summary tables more reflective of GPU bound real use cases. Test Plan: profiling ln -> linear: ``` python benchmarks/float8/profile_linear_float8.py --dtype_filter both ~/local/tmp --model_type ln_linear ``` new output, note that only kernels relevant to ln and linear are displayed ``` Summary of GPU time by CPU kernel experiment kernel category time_ms pct_gpu_time bw_gpbs 1 0_ref aten::mm 0_gemm 10.153 0.945 None 2 0_ref triton_red_fused_native_layer_norm_native_layer_norm_backward_0 2_other 0.350 0.033 None 0 0_ref triton_red_fused_native_layer_norm_0 2_other 0.241 0.022 None 12 1_float8 aten::_scaled_mm 0_gemm 5.182 0.736 None 16 1_float8 triton_red_fused__scaled_mm__to_copy_clamp_clone_mul_native_layer_norm_native_layer_norm_backwar... 1_f8_overhead 0.813 0.115 None 15 1_float8 triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_reciprocal_view_2 1_f8_overhead 0.302 0.043 None 5 1_float8 triton_red_fused_abs_max_native_layer_norm_0 1_f8_overhead 0.212 0.030 None 10 1_float8 triton_poi_fused__scaled_mm__to_copy_clamp_mul_native_layer_norm_view_5 1_f8_overhead 0.177 0.025 None 11 1_float8 triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_native_layer_norm_view_6 1_f8_overhead 0.150 0.021 None 13 1_float8 triton_red_fused_abs_max_0 1_f8_overhead 0.126 0.018 None 7 1_float8 triton_red_fused_abs_max_2 1_f8_overhead 0.060 0.008 None 3 1_float8 triton_per_fused_copy_max_roll_0 1_f8_overhead 0.005 0.001 None 6 1_float8 triton_red_fused__to_copy_abs_clamp_max_mul_native_layer_norm_reciprocal_1 1_f8_overhead 0.004 0.001 None 4 1_float8 triton_per_fused_copy_max_roll_1 1_f8_overhead 0.003 0.000 None 14 1_float8 triton_per_fused__scaled_mm__to_copy_abs_clamp_clone_max_mul_reciprocal_view_1 1_f8_overhead 0.003 0.000 None 8 1_float8 triton_per_fused_abs_fill_max_3 1_f8_overhead 0.003 0.000 None 9 1_float8 triton_poi_fused_reciprocal_4 2_other 0.002 0.000 None Float8 amax/scale sync approx ratio of total time: 0.006 Summary of time (ms) by kernel category experiment 0_ref 1_float8 f8_div_ref ref_div_f8 category 0_gemm 10.153 5.182 0.510 1.959 1_f8_overhead 0.000 1.858 inf 0.000 2_other 0.591 0.002 0.004 264.393 All 10.743 7.042 0.655 1.526 ``` Reviewers: Subscribers: Tasks: Tags: --- benchmarks/float8/profile_linear_float8.py | 154 +++++++++++---------- benchmarks/float8/utils.py | 64 ++++++++- 2 files changed, 139 insertions(+), 79 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 3f2047cfc..a0db58179 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -29,7 +29,7 @@ kernel_name_to_category, parse_bw_and_kernel_name, profiler_output_to_gpu_time_for_key, - profiler_output_to_time_by_kernel_name, + profiler_output_to_filtered_time_by_kernel_name, ) # don't truncate long kernel names @@ -312,85 +312,89 @@ def float8_forw_backward_wrapper(x): # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script f = io.StringIO() - with redirect_stdout(f): - # warm up - for _ in 
range(1): + try: + with redirect_stdout(f): + # warm up + for _ in range(1): + if dtype_filter != "float8": + ref_forw_backward(input_tensor) + if dtype_filter != "bfloat16": + float8_forw_backward_wrapper(input_tensor) + + profile_iters = 5 + ref_times, float8_times = None, None + data = [] + + num_leaf_tensors = 1 + len(list(m_ref.parameters())) + if dtype_filter != "float8": - ref_forw_backward(input_tensor) - if dtype_filter != "bfloat16": - float8_forw_backward_wrapper(input_tensor) - - profile_iters = 5 - ref_times, float8_times = None, None - data = [] - - if dtype_filter != "float8": - # Profile Reference Model - print("profiling ref") - ref_suffix = f"_{model_type}_ref_compile_{compile}.json" - ref_path = profile_path_prefix + ref_suffix - profile_config = ProfileConfig( - ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True - ) - p = profile_function(profile_config, ref_forw_backward, input_tensor) - print(f"saved {ref_path}") - ref_times = profiler_output_to_time_by_kernel_name(p) - total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters - for k, v in ref_times.items(): - v_ms = v / 1e3 / profile_iters - data.append( - [ - "0_ref", - k, - kernel_name_to_category(k), - v_ms, - v_ms / total_time_ms, - None, - ] + # Profile Reference Model + print("profiling ref") + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" + ref_path = profile_path_prefix + ref_suffix + profile_config = ProfileConfig( + ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True ) + p = profile_function(profile_config, ref_forw_backward, input_tensor) + print(f"saved {ref_path}") + ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) + total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters + for k, v in ref_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "0_ref", + k, + kernel_name_to_category(k), + v_ms, + v_ms / total_time_ms, + None, + ] + ) - if dtype_filter != "bfloat16": - # Profile Float8 Model - print("profiling float8") - float8_suffix = ( - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" - ) - float8_path = profile_path_prefix + float8_suffix - profile_config = ProfileConfig( - float8_path, - float8_suffix, - iters=profile_iters, - warmup_iters=2, - sync=True, - ) - p = profile_function( - profile_config, float8_forw_backward_wrapper, input_tensor - ) - print(f"saved {float8_path}") - float8_times = profiler_output_to_time_by_kernel_name(p) - total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters - for k, v in float8_times.items(): - v_ms = v / 1e3 / profile_iters - data.append( - [ - "1_float8", - k, - kernel_name_to_category(k), - v / 1e3 / profile_iters, - v_ms / total_time_ms, - None, - ] + if dtype_filter != "bfloat16": + # Profile Float8 Model + print("profiling float8") + float8_suffix = ( + f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" ) + float8_path = profile_path_prefix + float8_suffix + profile_config = ProfileConfig( + float8_path, + float8_suffix, + iters=profile_iters, + warmup_iters=2, + sync=True, + ) + p = profile_function( + profile_config, float8_forw_backward_wrapper, input_tensor + ) + print(f"saved {float8_path}") + float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors) + total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters + for k, v in float8_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "1_float8", + k, 
+ kernel_name_to_category(k), + v / 1e3 / profile_iters, + v_ms / total_time_ms, + None, + ] + ) + + # get the time spent per user annotation + sync_time_us = profiler_output_to_gpu_time_for_key( + p, "scale_amax_and_scales" + ) + sync_time_ms = sync_time_us / profile_iters / 1e3 + print(f"Sync time ms: {sync_time_ms}") - # get the time spent per user annotation - sync_time_us = profiler_output_to_gpu_time_for_key( - p, "scale_amax_and_scales" - ) - sync_time_ms = sync_time_us / profile_iters / 1e3 - print(f"Sync time ms: {sync_time_ms}") - - # print the redirected stdout back to regular stdout - print(f.getvalue()) + finally: + # print the redirected stdout back to regular stdout + print(f.getvalue()) # populate the triton kernel bandwidth for line in f.getvalue().split("\n"): diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index 106546b37..d0bd0f410 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -9,10 +9,44 @@ from typing import Optional -def profiler_output_to_time_by_kernel_name(prof): +def profiler_output_to_filtered_time_by_kernel_name( + prof, + num_iter: int, + num_leaf_tensors: int, +): """ - Input: a profiler with captured events. - Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name + Input: + * `prof`: a profiler with captured events + * `num_iter`: number of iterations used to capture `prof` + * `num_leaf_tensors`: number of leaf tensors to accumulate gradients to + Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name, + with the microbenchmark overhead filtered out + + Currently assumes that `prof` captured events from a microbenchmark which was + set up as follows: + + # + # Forward pass + # + + # Expected GPU kernel overhead: none + y = func(...) + + # Convenient way to set up the backward pass without caring about shapes + y_sum = y.sum() + + # Expected GPU kernel overhead: + # * the call to `sum` + + # + # Backward pass + # + y_sum.backward() + + # Expected GPU kernel overhead: + # * the call to `aten.fill_` to put a tensor with a single 1.0 value as the input to the backward + # * the call to `aten.copy_` to fill the first `grad_output` tensor with 1.0 + # * the call to `aten.add_` to accumulate grads, once per leaf tensor Note that if there are user_annotations in the captured events, `torch.profiler` will include their time in the total GPU time displayed at the bottom of @@ -23,13 +57,35 @@ def profiler_output_to_time_by_kernel_name(prof): thresh = 1e-10 kernel_name_to_gpu_time_us = collections.defaultdict(float) for e in key_averages: + # manually filter top-level CPU events with attributed CUDA time - # example CPU event row: + # example CPU event row from printing `key_averages`: # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1 # and it maps to this CUDA event: # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 
0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1 if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh): continue + + # manually filter expected microbenchmarking overhead, in order of execution + if e.key == 'aten::sum': + # forward pass sum + assert e.count == num_iter, f'unexpected number of iter for {e.key}' + continue + elif e.key == 'aten::fill_': + # filling the forward pass sum with 1.0 + assert e.count == num_iter, f'unexpected number of iter for {e.key}' + continue + elif e.key == 'aten::copy_': + # copying 1.0 from grad_out of `sum` to grad_out of next op + assert e.count == num_iter, f'unexpected number of iter for {e.key}' + continue + elif e.key == 'aten::add_': + # accumulating gradients into leaf tensors + assert e.count == (num_iter * num_leaf_tensors), f'unexpected number of iter for {e.key}' + continue + elif e.key == 'cudaDeviceSynchronize': + continue + kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total return kernel_name_to_gpu_time_us From 0a3b3288b4eaad501b80a62597e6ccf9fc42b327 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 8 Aug 2024 09:23:16 -0700 Subject: [PATCH 2/4] Fix Inductor bench BC change (#638) * Fix Inductor bench BC change * update * push * pish --- torchao/quantization/autoquant.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index e7d03f3f2..ff515df67 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -15,14 +15,19 @@ from .quant_primitives import ( safe_int_mm, ) -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_5 from torchao.quantization.utils import quantize_activation_per_token_absmax import torch.nn.functional as F + try: from torch._inductor.utils import do_bench -except: - from torch._inductor.runtime.runtime_utils import do_bench +except ImportError: + try: + from torch._inductor.runtime.runtime_utils import do_bench + except ImportError: + from torch._inductor.runtime.benchmarking import benchmarker + do_bench = benchmarker.benchmark __all__ = [ "AutoQuantizableLinearWeight", @@ -227,9 +232,13 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AFTER_2_3: + if TORCH_VERSION_AFTER_2_3 and not TORCH_VERSION_AFTER_2_5: from torch._inductor.runtime.runtime_utils import do_bench_gpu res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + elif TORCH_VERSION_AFTER_2_5 and torch.cuda.is_available(): + from torch._inductor.runtime.benchmarking import benchmarker + do_bench_gpu = benchmarker.benchmark_gpu + res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") else: res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") return res From 5f35645bb75b67dd217c1cf17495ecb4a18b82fa Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Thu, 8 Aug 2024 12:18:16 -0700 Subject: [PATCH 3/4] Skip failing inference test which requires change to PT core (#640) --- test/float8/test_inference_flows.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/float8/test_inference_flows.py b/test/float8/test_inference_flows.py index 735075163..988b44396 100644 --- a/test/float8/test_inference_flows.py +++ 
b/test/float8/test_inference_flows.py @@ -267,6 +267,7 @@ def setup_mock(self): not torch.cuda.is_available() or not is_H100, "CUDA not available or on non H100 machine", ) + @unittest.skip("Pytorch needs a fix to ensure codegen maintains stride order") def test_fp8_export(self): export_model = FeedForward().to("cuda") quant_config = QuantConfig(ActivationCasting.DYNAMIC) From e11201a62669f582d81cdb33e031a07fb8dfc4f3 Mon Sep 17 00:00:00 2001 From: Nicolas Macchioni Date: Thu, 8 Aug 2024 14:03:35 -0700 Subject: [PATCH 4/4] [BC][internal-first] Cleanup BC fixes (#641) Update autoquant.py --- torchao/quantization/autoquant.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index ff515df67..6eee43c51 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -20,15 +20,6 @@ import torch.nn.functional as F -try: - from torch._inductor.utils import do_bench -except ImportError: - try: - from torch._inductor.runtime.runtime_utils import do_bench - except ImportError: - from torch._inductor.runtime.benchmarking import benchmarker - do_bench = benchmarker.benchmark - __all__ = [ "AutoQuantizableLinearWeight", "autoquant", @@ -232,14 +223,16 @@ def do_autoquant_bench(op, *args, **kwargs): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): op(*args, **kwargs) - if TORCH_VERSION_AFTER_2_3 and not TORCH_VERSION_AFTER_2_5: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") - elif TORCH_VERSION_AFTER_2_5 and torch.cuda.is_available(): + if TORCH_VERSION_AFTER_2_5: from torch._inductor.runtime.benchmarking import benchmarker - do_bench_gpu = benchmarker.benchmark_gpu + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + elif TORCH_VERSION_AFTER_2_3: + from torch._inductor.runtime.runtime_utils import do_bench_gpu res = do_bench_gpu(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") else: + from torch._inductor.utils import do_bench res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") return res
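
Editor's note (not part of the patches above): for context on the first patch, here is a minimal usage sketch of the new `profiler_output_to_filtered_time_by_kernel_name` helper outside the benchmark script. The model, shapes, eager-mode setup, and import path are illustrative assumptions; the helper asserts exact per-op counts for the overhead it filters, so the sketch mirrors the warmed-up forward -> sum -> backward pattern its docstring describes.

```
# Hypothetical standalone use of the helper added to benchmarks/float8/utils.py.
# Assumes this runs from benchmarks/float8/ so `utils` is importable, as the
# benchmark script itself assumes.
import torch
from torch.profiler import ProfilerActivity, profile

from utils import profiler_output_to_filtered_time_by_kernel_name

# A bias-free linear keeps the backward free of extra aten::sum calls, so the
# helper's per-op count asserts should line up with its documented assumptions.
m = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

def fwd_bwd():
    y = m(x)
    # set up the backward pass without caring about output shapes
    y.sum().backward()

# warm up so .grad already exists and every profiled iteration
# accumulates gradients via aten::add_
for _ in range(2):
    fwd_bwd()

num_iter = 5
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(num_iter):
        fwd_bwd()
    torch.cuda.synchronize()  # filtered out as cudaDeviceSynchronize

# leaf tensors receiving grads here: the weight and the input
kernel_times = profiler_output_to_filtered_time_by_kernel_name(
    prof, num_iter, num_leaf_tensors=2
)
# values are self device time per kernel name, with the sum/fill_/copy_/add_
# microbenchmark overhead removed
for kernel_name, gpu_time in sorted(kernel_times.items(), key=lambda kv: -kv[1]):
    print(kernel_name, gpu_time)
```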