
Commit eea87c1

[Refactor] Refactor jit to _JitImplementation to support @tilelang.jit (#502)
* [Refactor] Rename `jit` class to `_JitImplementation` and improve debug path handling
  * Refactored the `jit` class into `_JitImplementation` for clarity and encapsulation.
  * Enhanced handling of `debug_root_path` to ensure it is correctly resolved to an absolute path when provided.
  * Updated the public `jit` function to serve as a decorator interface, allowing both default and configured usage.
  * Added validation in the Cython wrapper to ensure input tensors are contiguous, improving error handling.
* [Refactor] Improve formatting and handling in `_JitImplementation` and the `jit` function
  * Refactored the `_JitImplementation` class for readability by adjusting comment formatting and consolidating the conditions for setting `debug_root_path`.
  * Updated the `jit` function signature for better alignment and clarity in the parameter definitions.
  * Ensured consistent spacing and comments throughout the code for maintainability.
* [Refactor] Update GEMM test parameters
  * Set `num_stages` to 0 and reduced the matrix dimensions in the GEMM test in test_tilelang_jit_gemm.py to keep the JIT tests fast and consistent.
* [Refactor] Enhance buffer error logging in layout inference
  * Updated the warning message in layout inference to provide clearer context when a buffer cannot be inferred with the current layout inference rules.
  * Refactored tensor handling in the Cython wrapper so that input tensors are checked for contiguity before processing, improving robustness.
* bugfix
1 parent 30ca6da commit eea87c1
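
As a quick illustration of the decorator interface described in the commit message, a minimal usage sketch follows (the kernel-defining functions `my_kernel_a` and `my_kernel_b` are hypothetical placeholders; only `tilelang.jit` and its keyword-only options come from this commit):

    import tilelang

    # Style 1: bare decorator, default settings.
    @tilelang.jit
    def my_kernel_a(M, N, K):
        ...  # hypothetical function that builds and returns a TileLang program

    # Style 2: configured decorator; all options are keyword-only.
    @tilelang.jit(target="cuda", return_program=True)
    def my_kernel_b(M, N, K):
        ...  # with return_program=True, calling my_kernel_b(...) yields (program, kernel)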

4 files changed: +125 −61 lines

src/transform/layout_inference.cc

Lines changed: 4 additions & 3 deletions
@@ -301,9 +301,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {

     // Check if buffer exists in use_list_
     if (!use_list_.count(buffer)) {
-      LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
-                   << "Potential mismatch between inference updates and "
-                   << "use_list_.";
+      LOG(WARNING) << "Layout inference failed for buffer " << buffer
+                   << ". "
+                   << "The buffer cannot be inferred with current layout "
+                      "inference rules.";
       continue;
     }

testing/python/jit/test_tilelang_jit_gemm.py

Lines changed: 3 additions & 3 deletions
@@ -70,7 +70,7 @@ def run_gemm_kernel_jit(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=0,
     num_threads=128,
 ):
     matmul_kernel = matmul_kernel_jit(
@@ -120,9 +120,9 @@ def test_gemm_f16f16f16_nn_kernel_jit():
         "float16",
         "float16",
         128,
-        256,
+        128,
         32,
-        2,
+        0,
     )


tilelang/jit/__init__.py

Lines changed: 105 additions & 45 deletions
@@ -114,7 +114,7 @@ def __repr__(self):
 _RProg = TypeVar("_RProg", bound=Program)


-class jit:
+class _JitImplementation:
     # Overload __init__ to help type checkers understand the effect of return_program
     # The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
     @overload
@@ -190,76 +190,59 @@ def __init__(self,
             original program's result and the compiled kernel. If False, only the
             compiled kernel is returned (default: False).
         """
-        if debug_root_path is None:
-            # This logic was previously under 'if debug and debug_root_path is None:'
-            # Now, if debug_root_path is explicitly None, we don't try to set a default path.
-            # If a user wants debugging, they must provide a path.
-            pass
-        elif not path.isabs(debug_root_path):  # If a relative path is given, make it absolute
-            try:
-                # This assumes the file is part of a typical project structure
-                base_path = path.dirname(path.dirname(path.dirname(__file__)))
-                debug_root_path = path.join(base_path, debug_root_path)
-            except NameError:  # __file__ is not defined (e.g., in a REPL or notebook)
-                # Fallback to making it absolute based on current working directory if __file__ fails
-                debug_root_path = path.abspath(debug_root_path)
-
         self.out_idx = out_idx
         self.execution_backend = execution_backend
         self.target = target
         self.target_host = target_host
         self.verbose = verbose
         self.pass_configs = pass_configs
-        self.debug_root_path: Optional[str] = debug_root_path
-        self.return_program: bool = return_program
+        self.return_program = return_program  # Stored from args
+
+        # Corrected debug_root_path handling
+        self.debug_root_path = debug_root_path
+        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
+            try:
+                base_path = path.dirname(path.dirname(path.dirname(__file__)))
+                self.debug_root_path = path.join(base_path, self.debug_root_path)
+            except NameError:
+                self.debug_root_path = path.abspath(self.debug_root_path)
+        # If debug_root_path was None initially, it remains None.

         # Type hint the caches
         self._program_cache: Dict[tuple, _RProg] = {}
         self._kernel_cache: Dict[tuple, Kernel] = {}

     # Overload __call__ based on the value of self.return_program
     # This tells the type checker what the *wrapper* function will return.
-    # The wrapper will take the same parameters P as the original function.
-
-    # Case 1: return_program is True
     @overload
     def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
-        # This signature is chosen by the type checker if self.return_program is True
-        # (inferred from the __init__ call).
         ...

-    # Case 2: return_program is False (or not specified, defaulting to False)
     @overload
     def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
-        # This signature is chosen if self.return_program is False.
         ...

     # Actual implementation of __call__
     def __call__(
-            self, func: Union[Callable[_P, _RProg], PrimFunc]
-    ) -> Callable[_P, Any]:  # Any for implementation flexibility
+            self,
+            func: Callable[_P, _RProg]  # func is Union[Callable[_P, _RProg], PrimFunc] in original
+    ) -> Callable[_P, Any]:

         @functools.wraps(func)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:  # Use _P.args and _P.kwargs
-            # Create a hashable key. args is already a tuple.
-            # For kwargs, convert to a sorted tuple of items to ensure consistent ordering.
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
             key = (key_args_tuple, key_kwargs_tuple)

-            # Check if both program and kernel are cached.
-            # If program is not cached, we'll recompute both.
-            # (The original check 'key not in self._program_cache or key not in self._kernel_cache'
-            # implies that if either is missing, both are recomputed and stored.
-            # A simpler 'key not in self._program_cache' would often suffice if they are always
-            # added together.)
-            if key not in self._program_cache:  # Assuming if program isn't there, kernel isn't either or needs refresh
-                if isinstance(func, PrimFunc):
-                    program_result = func
-                elif isinstance(func, Callable):
-                    program_result = func(*args, **kwargs)
+            if key not in self._program_cache:
+                # Ensure 'func' (the original user function) is used correctly
+                program_result_source = func
+                if isinstance(program_result_source, PrimFunc):
+                    program_result = program_result_source
+                elif callable(program_result_source):
+                    program_result = program_result_source(*args, **kwargs)
                 else:
-                    raise ValueError(f"Invalid function type: {type(func)}")
+                    raise ValueError(f"Invalid function type: {type(program_result_source)}")

                 kernel_result = compile(
                     program_result,
@@ -271,18 +254,16 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.k
                     pass_configs=self.pass_configs,
                 )

-                if self.debug_root_path:  # Check if a path is provided
-                    func_name = func.__name__
+                if self.debug_root_path:
+                    func_name = getattr(func, '__name__', 'jit_kernel')  # Use func for name
                     kernel_file = f'tilelang_jit_kernel_{func_name}.c'
-                    # Ensure the debug directory exists
                     makedirs(self.debug_root_path, exist_ok=True)
                     with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
                         print(kernel_result.get_kernel_source(), file=f)

                 self._program_cache[key] = program_result
                 self._kernel_cache[key] = kernel_result

-            # Retrieve from cache (even if just populated)
             cached_program = self._program_cache[key]
             cached_kernel = self._kernel_cache[key]

@@ -292,3 +273,82 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.k
                 return cached_kernel

         return wrapper
+
+
+def jit(  # This is the new public interface
+        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+        *,  # Indicates subsequent arguments are keyword-only
+        out_idx: Any = None,
+        target: Union[str, Target] = "auto",
+        target_host: Union[str, Target] = None,
+        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
+        verbose: bool = False,
+        pass_configs: Optional[Dict[str, Any]] = None,
+        debug_root_path: Optional[str] = None,
+        return_program: bool = False):
+    """
+    Just-In-Time (JIT) compiler decorator for TileLang functions.
+
+    This decorator can be used in two ways:
+    1. Without arguments (e.g., `@tilelang.jit`):
+       Applies JIT compilation with default settings.
+    2. With arguments (e.g., `@tilelang.jit(target="cuda", return_program=True)`):
+       Configures the JIT compilation process with the specified options.
+
+    Parameters
+    ----------
+    func_or_out_idx : Any, optional
+        If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
+        If using `@tilelang.jit` directly on a function, this argument is implicitly
+        the function to be decorated (and `out_idx` will be `None`).
+    target : Union[str, Target], optional
+        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
+    target_host : Union[str, Target], optional
+        Target host for cross-compilation. Defaults to None.
+    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
+        Backend for kernel execution and argument passing. Defaults to "cython".
+    verbose : bool, optional
+        Enables verbose logging during compilation. Defaults to False.
+    pass_configs : Optional[Dict[str, Any]], optional
+        Configurations for TVM's pass context. Defaults to None.
+    debug_root_path : Optional[str], optional
+        Directory to save compiled kernel source for debugging. Defaults to None.
+    return_program : bool, optional
+        If True, the decorated function returns a tuple (original program's result, compiled kernel).
+        Otherwise, only the compiled kernel is returned. Defaults to False.
+
+    Returns
+    -------
+    Callable
+        Either a JIT-compiled wrapper around the input function, or a configured decorator
+        instance that can then be applied to a function.
+    """
+    if callable(func):
+        # Case 1: Used as @jit (func_or_out_idx is the function, others are defaults)
+        # Create a default _JitImplementation instance and apply it to the function.
+        default_decorator = _JitImplementation(
+            out_idx=out_idx,  # Explicitly None for the default case
+            target=target,
+            target_host=target_host,
+            execution_backend=execution_backend,
+            verbose=verbose,
+            pass_configs=pass_configs,
+            debug_root_path=debug_root_path,
+            return_program=return_program)
+        return default_decorator(func)
+    elif isinstance(func, PrimFunc):
+        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
+    else:
+        # Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx.
+        # Create a _JitImplementation instance with the provided/defaulted arguments.
+        # This instance is a decorator that will be applied to the function later.
+        configured_decorator = _JitImplementation(
+            out_idx=out_idx,  # Pass along; could be an actual out_idx or None
+            target=target,
+            target_host=target_host,
+            execution_backend=execution_backend,
+            verbose=verbose,
+            pass_configs=pass_configs,
+            debug_root_path=debug_root_path,
+            return_program=return_program)
+        return configured_decorator
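
As a side note on the wrapper above: the program and kernel caches are keyed on the call arguments, with keyword arguments sorted so that argument order does not create duplicate cache entries. A standalone sketch of that keying scheme (the helper name `make_cache_key` is illustrative, not part of the codebase):

    # Positional args stay a tuple; kwargs become a sorted tuple of items.
    # All argument values must therefore be hashable.
    def make_cache_key(args: tuple, kwargs: dict) -> tuple:
        return (args, tuple(sorted(kwargs.items())))

    assert make_cache_key((128, 128, 32), {"dtype": "float16", "num_stages": 0}) == \
           make_cache_key((128, 128, 32), {"num_stages": 0, "dtype": "float16"})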

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 13 additions & 10 deletions
@@ -116,20 +116,23 @@ cdef class CythonKernelWrapper:
         # Convert tensor pointers to C void pointers for kernel call
         call_args = []
         for i in range(len(tensor_list)):
-            if isinstance(tensor_list[i], torch.Tensor):
-                call_args.append(ctypes.c_void_p(tensor_list[i].data_ptr()))
-            elif isinstance(tensor_list[i], int):
+            tensor = tensor_list[i]
+            if isinstance(tensor, torch.Tensor):
+                if not tensor.is_contiguous():
+                    raise ValueError(f"Input tensor at index {i} must be contiguous")
+                call_args.append(ctypes.c_void_p(tensor.data_ptr()))
+            elif isinstance(tensor, int):
                 # Dynamic symbolics which are passed as integer arguments
                 if i in self.ptr_map:
-                    call_args.append(ctypes.c_void_p(tensor_list[i]))
+                    call_args.append(ctypes.c_void_p(tensor))
                 else:
-                    call_args.append(tensor_list[i])
-            elif isinstance(tensor_list[i], float):
-                call_args.append(ctypes.c_float(tensor_list[i]))
-            elif isinstance(tensor_list[i], bool):
-                call_args.append(ctypes.c_bool(tensor_list[i]))
+                    call_args.append(tensor)
+            elif isinstance(tensor, float):
+                call_args.append(ctypes.c_float(tensor))
+            elif isinstance(tensor, bool):
+                call_args.append(ctypes.c_bool(tensor))
             else:
-                raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")
+                raise ValueError(f"Unsupported tensor type: {type(tensor)}")

         # Check buffer device
         # cdef str tensor_list_device_type = tensor_list[0].device.type
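
From the caller's side, the contiguity check added above rejects non-contiguous views (for example, a transposed tensor) before the kernel is launched. A small sketch of the expected behavior; `check_contiguous` is a plain-Python stand-in that mirrors the check, not the Cython wrapper itself:

    import torch

    def check_contiguous(tensor_list):
        # Mirrors the validation added in the Cython wrapper (sketch only).
        for i, t in enumerate(tensor_list):
            if isinstance(t, torch.Tensor) and not t.is_contiguous():
                raise ValueError(f"Input tensor at index {i} must be contiguous")

    a = torch.randn(128, 64)
    b = a.t()                              # a transposed view is not contiguous
    try:
        check_contiguous([a, b])
    except ValueError as e:
        print(e)                           # -> Input tensor at index 1 must be contiguous

    check_contiguous([a, b.contiguous()])  # passes after materializing a contiguous copy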
