Skip to content

Commit 7e94aae

Browse files
committed
[AUTOTUNER] adding simple report flag for autotuner runs
1 parent 5ee38fe commit 7e94aae

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

python/triton/runtime/autotuner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(
2222
prune_configs_by: Dict = None,
2323
warmup=25,
2424
rep=100,
25+
report=False,
2526
):
2627
"""
2728
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -77,6 +78,7 @@ def _post_hook(args):
7778
self.fn = fn
7879
self.num_warmups = warmup
7980
self.num_reps = rep
81+
self.report = report
8082

8183
def _bench(self, *args, config, **meta):
8284
# check for conflicts, i.e. meta-parameters both provided
@@ -109,6 +111,8 @@ def kernel_call():
109111

110112
def run(self, *args, **kwargs):
111113
self.nargs = dict(zip(self.arg_names, args))
114+
autotune_start = time.time()
115+
used_cached_result = True
112116
if len(self.configs) > 1:
113117
all_args = {**self.nargs, **kwargs}
114118
_args = []
@@ -122,6 +126,7 @@ def run(self, *args, **kwargs):
122126
key = tuple(key)
123127
if key not in self.cache:
124128
# prune configs
129+
used_cached_result = False
125130
pruned_configs = self.prune_configs(kwargs)
126131
bench_start = time.time()
127132
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
@@ -134,6 +139,11 @@ def run(self, *args, **kwargs):
134139
else:
135140
config = self.configs[0]
136141
self.best_config = config
142+
if self.report and not used_cached_result:
143+
autotune_stop = time.time()
144+
print(
145+
f"Autotuner for function {self.fn} finished after {autotune_stop-autotune_start:.2f}s; best config selected: {self.best_config};"
146+
)
137147
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
138148
if config.pre_hook is not None:
139149
config.pre_hook(full_nargs)
@@ -224,7 +234,8 @@ def __str__(self):
224234
return ", ".join(res)
225235

226236

227-
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100):
237+
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100,
238+
report=False):
228239
"""
229240
Decorator for auto-tuning a :code:`triton.jit`'d function.
230241
@@ -261,10 +272,13 @@ def kernel(x_ptr, x_size, **META):
261272
:type warmup: int
262273
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
263274
:type rep: int
275+
:param report: Flag to enable printing the selected configuration
276+
:type report: bool
264277
"""
265278

266279
def decorator(fn):
267-
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep)
280+
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep,
281+
report)
268282

269283
return decorator
270284

0 commit comments

Comments
 (0)