@@ -22,6 +22,7 @@ def __init__(
2222 prune_configs_by : Dict = None ,
2323 warmup = 25 ,
2424 rep = 100 ,
25+ report = False ,
2526 ):
2627 """
2728 :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -77,6 +78,7 @@ def _post_hook(args):
7778 self .fn = fn
7879 self .num_warmups = warmup
7980 self .num_reps = rep
81+ self .report = report
8082
8183 def _bench (self , * args , config , ** meta ):
8284 # check for conflicts, i.e. meta-parameters both provided
@@ -109,6 +111,8 @@ def kernel_call():
109111
110112 def run (self , * args , ** kwargs ):
111113 self .nargs = dict (zip (self .arg_names , args ))
114+ autotune_start = time .time ()
115+ used_cached_result = True
112116 if len (self .configs ) > 1 :
113117 all_args = {** self .nargs , ** kwargs }
114118 _args = []
@@ -122,6 +126,7 @@ def run(self, *args, **kwargs):
122126 key = tuple (key )
123127 if key not in self .cache :
124128 # prune configs
129+ used_cached_result = False
125130 pruned_configs = self .prune_configs (kwargs )
126131 bench_start = time .time ()
127132 timings = {config : self ._bench (* args , config = config , ** kwargs ) for config in pruned_configs }
@@ -134,6 +139,11 @@ def run(self, *args, **kwargs):
134139 else :
135140 config = self .configs [0 ]
136141 self .best_config = config
142+ if self .report and not used_cached_result :
143+ autotune_stop = time .time ()
144+ print (
145+ f"Autotuner for function { self .fn } finished after { autotune_stop - autotune_start :.2f} s; best config selected: { self .best_config } ;"
146+ )
137147 full_nargs = {** self .nargs , ** kwargs , ** self .best_config .kwargs }
138148 if config .pre_hook is not None :
139149 config .pre_hook (full_nargs )
@@ -224,7 +234,8 @@ def __str__(self):
224234 return ", " .join (res )
225235
226236
227- def autotune (configs , key , prune_configs_by = None , reset_to_zero = None , restore_value = None , warmup = 25 , rep = 100 ):
237+ def autotune (configs , key , prune_configs_by = None , reset_to_zero = None , restore_value = None , warmup = 25 , rep = 100 ,
238+ report = False ):
228239 """
229240 Decorator for auto-tuning a :code:`triton.jit`'d function.
230241
@@ -261,10 +272,13 @@ def kernel(x_ptr, x_size, **META):
261272 :type warmup: int
262273 :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
263274 :type rep: int
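275+ :param report: If True, print the elapsed autotuning time and the selected best config each time tuning actually runs (nothing is printed when a cached result is reused), defaults to False.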
276+ :type report: bool
264277 """
265278
266279 def decorator (fn ):
267- return Autotuner (fn , fn .arg_names , configs , key , reset_to_zero , restore_value , prune_configs_by , warmup , rep )
280+ return Autotuner (fn , fn .arg_names , configs , key , reset_to_zero , restore_value , prune_configs_by , warmup , rep ,
281+ report )
268282
269283 return decorator
270284
0 commit comments