 import sympy
 import torch
 import torch.nn as nn
-import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
 from utils import (
     profiler_output_to_filtered_time_by_kernel_name,
 )
 
+import torchao
 from torchao.float8 import (
     Float8LinearConfig,
     convert_to_float8_training,
@@ -83,20 +83,6 @@ def forward(self, x):
         return x
 
 
-# TODO(next): hook this up
-
-
-def benchmark_fn_in_sec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean
-
-
 def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
@@ -232,6 +218,8 @@ def run(
         float8_recipe_name = "tensorwise"
 
     print(f"GPU: {torch.cuda.get_device_name(0)}")
+    print(f"torch version: {torch.__version__}")
+    print(f"torchao version: {torchao.__version__}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
     print(f"float8_recipe_name: {float8_recipe_name}")