Skip to content

Commit feb13ca

Browse files
bringleinjlebar
andauthored
[Frontend] Add TRITON_PRINT_AUTOTUNING flag (#3411)
When the `TRITON_PRINT_AUTOTUNING` envvar is set to 1, we print a message after autotuning a kernel. For example: ``` Triton autotuning for function JITFunction(__main__:fused_add_rmsnorm_triton) finished after 4.15s; best config selected: BLOCK_N_SIZE: 512, num_warps: 8, num_ctas: 1, num_stages: 3; ``` Co-authored-by: Justin Lebar <justin.lebar@gmail.com>
1 parent 8237f1b commit feb13ca

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

python/triton/runtime/autotuner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import builtins
4+
import os
45
import time
56
from typing import Dict
67

@@ -109,6 +110,7 @@ def kernel_call():
109110

110111
def run(self, *args, **kwargs):
111112
self.nargs = dict(zip(self.arg_names, args))
113+
used_cached_result = True
112114
if len(self.configs) > 1:
113115
all_args = {**self.nargs, **kwargs}
114116
_args = []
@@ -122,6 +124,7 @@ def run(self, *args, **kwargs):
122124
key = tuple(key)
123125
if key not in self.cache:
124126
# prune configs
127+
used_cached_result = False
125128
pruned_configs = self.prune_configs(kwargs)
126129
bench_start = time.time()
127130
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
@@ -134,6 +137,9 @@ def run(self, *args, **kwargs):
134137
else:
135138
config = self.configs[0]
136139
self.best_config = config
140+
if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
141+
print(f"Triton autotuning for function {self.fn} finished after "
142+
f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
137143
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
138144
if config.pre_hook is not None:
139145
config.pre_hook(full_nargs)

0 commit comments

Comments
 (0)