33This module provides functionality for auto-tuning tilelang programs, including JIT compilation
44and performance optimization through configuration search.
55"""
6+ from __future__ import annotations
67
78import tilelang
89from tilelang import tvm as tvm
910from tvm .tir import PrimFunc , Var
1011from tvm .target import Target
1112import inspect
1213from functools import partial
13- from typing import (Callable , List , Literal , Any , Optional , Union , Dict , overload , Tuple )
14+ from typing import (Callable , Literal , Any , overload )
1415from tqdm import tqdm
1516import logging
1617import functools
@@ -103,8 +104,8 @@ class AutoTuner:
103104 compile_args = CompileArgs ()
104105 profile_args = ProfileArgs ()
105106
106- _kernel_parameters : Optional [ Tuple [ str , ...]] = None
107- _function_parameters : Optional [ Dict [ str , Any ]] = None
107+ _kernel_parameters : tuple [ str , ...] | None = None
108+ _function_parameters : dict [ str , Any ] | None = None
108109 _lock = threading .Lock () # For thread safety
109110 _memory_cache = {} # In-memory cache dictionary
110111 cache_dir : Path = Path (env .TILELANG_CACHE_DIR ) / "autotuner"
@@ -131,12 +132,12 @@ def from_kernel(cls, kernel: Callable, configs):
131132 return cls (kernel , configs )
132133
133134 def set_compile_args (self ,
134- out_idx : Union [ List [ int ], int , None ] = None ,
135+ out_idx : list [ int ] | int | None = None ,
135136 target : Literal ['auto' , 'cuda' , 'hip' ] = 'auto' ,
136137 execution_backend : Literal ["dlpack" , "ctypes" , "cython" ] = "cython" ,
137- target_host : Union [ str , Target ] = None ,
138+ target_host : str | Target = None ,
138139 verbose : bool = False ,
139- pass_configs : Optional [ Dict [ str , Any ]] = None ):
140+ pass_configs : dict [ str , Any ] | None = None ):
140141 """Set compilation arguments for the auto-tuner.
141142
142143 Args:
@@ -223,12 +224,12 @@ def set_profile_args(self,
223224
224225 return self
225226
226- def set_kernel_parameters (self , k_parameters : Tuple [str , ...], f_parameters : Dict [str , Any ]):
227+ def set_kernel_parameters (self , k_parameters : tuple [str , ...], f_parameters : dict [str , Any ]):
227228 # for cache key generation
228229 self ._kernel_parameters = k_parameters
229230 self ._function_parameters = f_parameters
230231
231- def generate_cache_key (self , parameters : Dict [str , Any ]) -> Optional [ AutotuneResult ] :
232+ def generate_cache_key (self , parameters : dict [str , Any ]) -> AutotuneResult | None :
232233 """Generate a cache key for the auto-tuning process.
233234 """
234235
@@ -307,8 +308,8 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
307308 return result
308309
309310 best_latency : float = 1e8
310- best_config : Optional [ Dict [ str , Any ]] = None
311- best_kernel : Optional [ tilelang .JITKernel ] = None
311+ best_config : dict [ str , Any ] | None = None
312+ best_kernel : tilelang .JITKernel | None = None
312313
313314 def _compile (** config_arg ) -> tilelang .JITKernel :
314315 compile_args = self .compile_args
@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
591592 warmup : int = 25
592593 rep : int = 100
593594 timeout : int = 100
594- configs : Union [ Dict , Callable ] = None
595+ configs : dict | Callable = None
595596 supply_type : tilelang .TensorSupplyType = tilelang .TensorSupplyType .Auto
596597 ref_prog : Callable = None
597598 supply_prog : Callable = None
@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
603604 cache_input_tensors : bool = False
604605
605606 def __init__ (self ,
606- configs : Union [ Dict , Callable ] ,
607+ configs : dict | Callable ,
607608 warmup : int = 25 ,
608609 rep : int = 100 ,
609610 timeout : int = 100 ,
@@ -653,12 +654,12 @@ def __init__(self,
653654 self .cache_input_tensors = cache_input_tensors # Reuse inputs
654655
655656 # Cache for storing tuned kernel implementations
656- self ._tuner_cache : Dict [tuple , tilelang .JITKernel ] = {} # (args, kwargs) -> compiled kernel
657+ self ._tuner_cache : dict [tuple , tilelang .JITKernel ] = {} # (args, kwargs) -> compiled kernel
657658
658659 # This tells the type checker what the *wrapper* function will return.
659660 # this is for linting, please do not remove it.
660661 @overload
661- def __call__ (self , fn : Callable [_P , _RProg ]) -> Callable [_P , Tuple [_RProg , AutotuneResult ]]:
662+ def __call__ (self , fn : Callable [_P , _RProg ]) -> Callable [_P , tuple [_RProg , AutotuneResult ]]:
662663 ...
663664
664665 @overload
@@ -720,9 +721,9 @@ def jit_compile(**config_arg):
720721
721722
722723def autotune ( # This is the new public interface
723- func : Union [ Callable [_P , _RProg ], PrimFunc , None ] = None ,
724+ func : Callable [_P , _RProg ] | PrimFunc | None = None ,
724725 * , # Indicates subsequent arguments are keyword-only
725- configs : Union [ Dict , Callable ] ,
726+ configs : dict | Callable ,
726727 # profile arguments
727728 warmup : int = 25 ,
728729 rep : int = 100 ,
0 commit comments