@@ -114,7 +114,7 @@ def __repr__(self):
114114_RProg = TypeVar ("_RProg" , bound = Program )
115115
116116
117- class jit :
117+ class _JitImplementation :
118118 # Overload __init__ to help type checkers understand the effect of return_program
119119 # The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
120120 @overload
@@ -190,76 +190,59 @@ def __init__(self,
190190 original program's result and the compiled kernel. If False, only the
191191 compiled kernel is returned (default: False).
192192 """
193- if debug_root_path is None :
194- # This logic was previously under 'if debug and debug_root_path is None:'
195- # Now, if debug_root_path is explicitly None, we don't try to set a default path.
196- # If a user wants debugging, they must provide a path.
197- pass
198- elif not path .isabs (debug_root_path ): # If a relative path is given, make it absolute
199- try :
200- # This assumes the file is part of a typical project structure
201- base_path = path .dirname (path .dirname (path .dirname (__file__ )))
202- debug_root_path = path .join (base_path , debug_root_path )
203- except NameError : # __file__ is not defined (e.g., in a REPL or notebook)
204- # Fallback to making it absolute based on current working directory if __file__ fails
205- debug_root_path = path .abspath (debug_root_path )
206-
207193 self .out_idx = out_idx
208194 self .execution_backend = execution_backend
209195 self .target = target
210196 self .target_host = target_host
211197 self .verbose = verbose
212198 self .pass_configs = pass_configs
213- self .debug_root_path : Optional [str ] = debug_root_path
214- self .return_program : bool = return_program
199+ self .return_program = return_program # Stored from args
200+
201+ # Corrected debug_root_path handling
202+ self .debug_root_path = debug_root_path
203+ if self .debug_root_path is not None and not path .isabs (self .debug_root_path ):
204+ try :
205+ base_path = path .dirname (path .dirname (path .dirname (__file__ )))
206+ self .debug_root_path = path .join (base_path , self .debug_root_path )
207+ except NameError :
208+ self .debug_root_path = path .abspath (self .debug_root_path )
209+ # If debug_root_path was None initially, it remains None.
215210
216211 # Type hint the caches
217212 self ._program_cache : Dict [tuple , _RProg ] = {}
218213 self ._kernel_cache : Dict [tuple , Kernel ] = {}
219214
220215 # Overload __call__ based on the value of self.return_program
221216 # This tells the type checker what the *wrapper* function will return.
222- # The wrapper will take the same parameters P as the original function.
223-
224- # Case 1: return_program is True
225217 @overload
226218 def __call__ (self , func : Callable [_P , _RProg ]) -> Callable [_P , Tuple [_RProg , Kernel ]]:
227- # This signature is chosen by the type checker if self.return_program is True
228- # (inferred from the __init__ call).
229219 ...
230220
231- # Case 2: return_program is False (or not specified, defaulting to False)
232221 @overload
233222 def __call__ (self , func : Callable [_P , _RProg ]) -> Callable [_P , Kernel ]:
234- # This signature is chosen if self.return_program is False.
235223 ...
236224
237225 # Actual implementation of __call__
238226 def __call__ (
239- self , func : Union [Callable [_P , _RProg ], PrimFunc ]
240- ) -> Callable [_P , Any ]: # Any for implementation flexibility
227+ self ,
228+ func : Callable [_P , _RProg ] # func is Union[Callable[_P, _RProg], PrimFunc] in original
229+ ) -> Callable [_P , Any ]:
241230
242231 @functools .wraps (func )
243- def wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> Any : # Use _P.args and _P.kwargs
244- # Create a hashable key. args is already a tuple.
245- # For kwargs, convert to a sorted tuple of items to ensure consistent ordering.
232+ def wrapper (* args : _P .args , ** kwargs : _P .kwargs ) -> Any :
246233 key_args_tuple = args
247234 key_kwargs_tuple = tuple (sorted (kwargs .items ()))
248235 key = (key_args_tuple , key_kwargs_tuple )
249236
250- # Check if both program and kernel are cached.
251- # If program is not cached, we'll recompute both.
252- # (The original check 'key not in self._program_cache or key not in self._kernel_cache'
253- # implies that if either is missing, both are recomputed and stored.
254- # A simpler 'key not in self._program_cache' would often suffice if they are always
255- # added together.)
256- if key not in self ._program_cache : # Assuming if program isn't there, kernel isn't either or needs refresh
257- if isinstance (func , PrimFunc ):
258- program_result = func
259- elif isinstance (func , Callable ):
260- program_result = func (* args , ** kwargs )
237+ if key not in self ._program_cache :
238+ # Ensure 'func' (the original user function) is used correctly
239+ program_result_source = func
240+ if isinstance (program_result_source , PrimFunc ):
241+ program_result = program_result_source
242+ elif callable (program_result_source ):
243+ program_result = program_result_source (* args , ** kwargs )
261244 else :
262- raise ValueError (f"Invalid function type: { type (func )} " )
245+ raise ValueError (f"Invalid function type: { type (program_result_source )} " )
263246
264247 kernel_result = compile (
265248 program_result ,
@@ -271,18 +254,16 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.k
271254 pass_configs = self .pass_configs ,
272255 )
273256
274- if self .debug_root_path : # Check if a path is provided
275- func_name = func . __name__
257+ if self .debug_root_path :
258+ func_name = getattr ( func , ' __name__' , 'jit_kernel' ) # Use func for name
276259 kernel_file = f'tilelang_jit_kernel_{ func_name } .c'
277- # Ensure the debug directory exists
278260 makedirs (self .debug_root_path , exist_ok = True )
279261 with open (path .join (self .debug_root_path , kernel_file ), 'w' ) as f :
280262 print (kernel_result .get_kernel_source (), file = f )
281263
282264 self ._program_cache [key ] = program_result
283265 self ._kernel_cache [key ] = kernel_result
284266
285- # Retrieve from cache (even if just populated)
286267 cached_program = self ._program_cache [key ]
287268 cached_kernel = self ._kernel_cache [key ]
288269
@@ -292,3 +273,82 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.k
292273 return cached_kernel
293274
294275 return wrapper
276+
277+
278+ def jit ( # This is the new public interface
279+ func : Union [Callable [_P , _RProg ], PrimFunc , None ] = None ,
280+ * , # Indicates subsequent arguments are keyword-only
281+ out_idx : Any = None ,
282+ target : Union [str , Target ] = "auto" ,
283+ target_host : Union [str , Target ] = None ,
284+ execution_backend : Literal ["dlpack" , "ctypes" , "cython" ] = "cython" ,
285+ verbose : bool = False ,
286+ pass_configs : Optional [Dict [str , Any ]] = None ,
287+ debug_root_path : Optional [str ] = None ,
288+ return_program : bool = False ):
289+ """
290+ Just-In-Time (JIT) compiler decorator for TileLang functions.
291+
292+ This decorator can be used in two ways:
293+ 1. Without arguments (e.g., `@tilelang.jit`):
294+ Applies JIT compilation with default settings.
295+ 2. With arguments (e.g., `@tilelang.jit(target="cuda", return_program=True)`):
296+ Configures the JIT compilation process with the specified options.
297+
298+ Parameters
299+ ----------
300+ func_or_out_idx : Any, optional
301+ If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
302+ If using `@tilelang.jit` directly on a function, this argument is implicitly
303+ the function to be decorated (and `out_idx` will be `None`).
304+ target : Union[str, Target], optional
305+ Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
306+ target_host : Union[str, Target], optional
307+ Target host for cross-compilation. Defaults to None.
308+ execution_backend : Literal["dlpack", "ctypes", "cython"], optional
309+ Backend for kernel execution and argument passing. Defaults to "cython".
310+ verbose : bool, optional
311+ Enables verbose logging during compilation. Defaults to False.
312+ pass_configs : Optional[Dict[str, Any]], optional
313+ Configurations for TVM's pass context. Defaults to None.
314+ debug_root_path : Optional[str], optional
315+ Directory to save compiled kernel source for debugging. Defaults to None.
316+ return_program : bool, optional
317+ If True, the decorated function returns a tuple (original program's result, compiled kernel).
318+ Otherwise, only the compiled kernel is returned. Defaults to False.
319+
320+ Returns
321+ -------
322+ Callable
323+ Either a JIT-compiled wrapper around the input function, or a configured decorator
324+ instance that can then be applied to a function.
325+ """
326+ if callable (func ):
327+ # Case 1: Used as @jit (func_or_out_idx is the function, others are defaults)
328+ # Create a default _JitImplementation instance and apply it to the function.
329+ default_decorator = _JitImplementation (
330+ out_idx = out_idx , # Explicitly None for the default case
331+ target = target ,
332+ target_host = target_host ,
333+ execution_backend = execution_backend ,
334+ verbose = verbose ,
335+ pass_configs = pass_configs ,
336+ debug_root_path = debug_root_path ,
337+ return_program = return_program )
338+ return default_decorator (func )
339+ elif isinstance (func , PrimFunc ):
340+ raise ValueError ("Use tilelang.jit to decorate prim_func is not supported yet." )
341+ else :
342+ # Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx.
343+ # Create a _JitImplementation instance with the provided/defaulted arguments.
344+ # This instance is a decorator that will be applied to the function later.
345+ configured_decorator = _JitImplementation (
346+ out_idx = out_idx , # Pass along; could be an actual out_idx or None
347+ target = target ,
348+ target_host = target_host ,
349+ execution_backend = execution_backend ,
350+ verbose = verbose ,
351+ pass_configs = pass_configs ,
352+ debug_root_path = debug_root_path ,
353+ return_program = return_program )
354+ return configured_decorator
0 commit comments