@@ -49,8 +49,8 @@ class Experiment:
4949
5050
5151def get_configs () -> List [ExperimentConfig ]:
52- input_shapes = [(2 ** 8 , 4096 ), ( 2 ** 12 , 4096 ), ( 2 ** 16 , 4096 )]
53- n_groups_list = [4 , 8 , 16 ]
52+ input_shapes = [(16640 , 5120 )] # (Mg, K)
53+ n_groups_list = [16 , 128 ]
5454 high_precision_dtypes = [torch .bfloat16 ]
5555 configs = []
5656 for input_shape , n_groups , high_precision_dtype in itertools .product (
@@ -129,6 +129,7 @@ def run_triton(
129129
130130 # bench torch
131131 compiled_run_torch = torch .compile (run_torch )
132+ warmup (compiled_run_torch , input_row_major , input_col_major , offs )
132133 torch_time_us = benchmark_cuda_function_in_microseconds (
133134 compiled_run_torch , input_row_major , input_col_major , offs
134135 )
@@ -152,6 +153,7 @@ def print_results(experiments: List[Experiment]):
152153 "high_precision_dtype" ,
153154 "torch_time_us" ,
154155 "triton_time_us" ,
156+ "triton_speedup" ,
155157 ]
156158 rows = []
157159 for experiment in experiments :
@@ -165,6 +167,7 @@ def print_results(experiments: List[Experiment]):
165167 experiment .config .high_precision_dtype ,
166168 experiment .result .torch_time_us ,
167169 experiment .result .triton_time_us ,
170+ f"{ experiment .result .torch_time_us / experiment .result .triton_time_us :.2f} x" ,
168171 ]
169172 )
170173 print (tabulate (rows , headers = headers ))
0 commit comments