Skip to content

Commit 7d8b4b6

Browse files
committed
use existing bench_time; rename flag; update docstring
1 parent 7e94aae commit 7d8b4b6

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

python/triton/runtime/autotuner.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
prune_configs_by: Dict = None,
2323
warmup=25,
2424
rep=100,
25-
report=False,
25+
print_autotune_stats=False,
2626
):
2727
"""
2828
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -78,7 +78,7 @@ def _post_hook(args):
7878
self.fn = fn
7979
self.num_warmups = warmup
8080
self.num_reps = rep
81-
self.report = report
81+
self.print_autotune_stats = print_autotune_stats
8282

8383
def _bench(self, *args, config, **meta):
8484
# check for conflicts, i.e. meta-parameters both provided
@@ -111,7 +111,6 @@ def kernel_call():
111111

112112
def run(self, *args, **kwargs):
113113
self.nargs = dict(zip(self.arg_names, args))
114-
autotune_start = time.time()
115114
used_cached_result = True
116115
if len(self.configs) > 1:
117116
all_args = {**self.nargs, **kwargs}
@@ -139,10 +138,9 @@ def run(self, *args, **kwargs):
139138
else:
140139
config = self.configs[0]
141140
self.best_config = config
142-
if self.report and not used_cached_result:
143-
autotune_stop = time.time()
141+
if self.print_autotune_stats and not used_cached_result:
144142
print(
145-
f"Autotuner for function {self.fn} finished after {autotune_stop-autotune_start:.2f}s; best config selected: {self.best_config};"
143+
f"Autotuner for function {self.fn} finished after {self.bench_time:.2f}s; best config selected: {self.best_config};"
146144
)
147145
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
148146
if config.pre_hook is not None:
@@ -235,7 +233,7 @@ def __str__(self):
235233

236234

237235
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, warmup=25, rep=100,
238-
report=False):
236+
print_autotune_stats=False):
239237
"""
240238
Decorator for auto-tuning a :code:`triton.jit`'d function.
241239
@@ -272,13 +270,13 @@ def kernel(x_ptr, x_size, **META):
272270
:type warmup: int
273271
:param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
274272
:type rep: int
275-
:param report: Flag to enable printing the selected configuration
273+
:param print_autotune_stats: If set to true, a log message will be printed after each autotune evaluation containing the benchmark time and the selected configuration.
276274
:type report: bool
277275
"""
278276

279277
def decorator(fn):
280278
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, prune_configs_by, warmup, rep,
281-
report)
279+
print_autotune_stats)
282280

283281
return decorator
284282

0 commit comments

Comments
 (0)