@@ -22,7 +22,7 @@ def __init__(
2222 prune_configs_by : Dict = None ,
2323 warmup = 25 ,
2424 rep = 100 ,
25- report = False ,
25+ print_autotune_stats = False ,
2626 ):
2727 """
2828 :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -78,7 +78,7 @@ def _post_hook(args):
7878 self .fn = fn
7979 self .num_warmups = warmup
8080 self .num_reps = rep
81- self .report = report
81+ self .print_autotune_stats = print_autotune_stats
8282
8383 def _bench (self , * args , config , ** meta ):
8484 # check for conflicts, i.e. meta-parameters both provided
@@ -111,7 +111,6 @@ def kernel_call():
111111
112112 def run (self , * args , ** kwargs ):
113113 self .nargs = dict (zip (self .arg_names , args ))
114- autotune_start = time .time ()
115114 used_cached_result = True
116115 if len (self .configs ) > 1 :
117116 all_args = {** self .nargs , ** kwargs }
@@ -139,10 +138,9 @@ def run(self, *args, **kwargs):
139138 else :
140139 config = self .configs [0 ]
141140 self .best_config = config
142- if self .report and not used_cached_result :
143- autotune_stop = time .time ()
141+ if self .print_autotune_stats and not used_cached_result :
144142 print (
145- f"Autotuner for function { self .fn } finished after { autotune_stop - autotune_start :.2f} s; best config selected: { self .best_config } ;"
143+ f"Autotuner for function { self .fn } finished after { self . bench_time :.2f} s; best config selected: { self .best_config } ;"
146144 )
147145 full_nargs = {** self .nargs , ** kwargs , ** self .best_config .kwargs }
148146 if config .pre_hook is not None :
@@ -235,7 +233,7 @@ def __str__(self):
235233
236234
237235def autotune (configs , key , prune_configs_by = None , reset_to_zero = None , restore_value = None , warmup = 25 , rep = 100 ,
238- report = False ):
236+ print_autotune_stats = False ):
239237 """
240238 Decorator for auto-tuning a :code:`triton.jit`'d function.
241239
@@ -272,13 +270,13 @@ def kernel(x_ptr, x_size, **META):
272270 :type warmup: int
273271 :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
274272 :type rep: int
275- :param report: Flag to enable printing the selected configuration
273+ :param print_autotune_stats: If set to true, a log message will be printed after each autotune evaluation containing the benchmark time and the selected configuration.
276274 :type report: bool
277275 """
278276
279277 def decorator (fn ):
280278 return Autotuner (fn , fn .arg_names , configs , key , reset_to_zero , restore_value , prune_configs_by , warmup , rep ,
281- report )
279+ print_autotune_stats )
282280
283281 return decorator
284282
0 commit comments