diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 09bfa139aecc..facb640d9b34 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -1,6 +1,7 @@ from __future__ import annotations import builtins +import os import time from typing import Dict @@ -109,6 +110,7 @@ def kernel_call(): def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True if len(self.configs) > 1: all_args = {**self.nargs, **kwargs} _args = [] @@ -122,6 +124,7 @@ def run(self, *args, **kwargs): key = tuple(key) if key not in self.cache: # prune configs + used_cached_result = False pruned_configs = self.prune_configs(kwargs) bench_start = time.time() timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} @@ -134,6 +137,9 @@ def run(self, *args, **kwargs): else: config = self.configs[0] self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.fn} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: config.pre_hook(full_nargs)