Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tilelang/autotuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ def wrapper(*args, **kwargs):
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)

compile_arguments = fn(__return_compile_arguments=True)

autotuner = AutoTuner(
fn, configs=configs).set_profile_args(
supply_type=self.supply_type,
Expand All @@ -563,13 +565,22 @@ def jit_compile(**config_arg):
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)

autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key)

autotuner.run = partial(autotuner.run, warmup, rep, timeout)

artifact = autotuner.run()

self._tuner_cache[key] = artifact.kernel

return self._tuner_cache[key]
Expand Down
6 changes: 2 additions & 4 deletions tilelang/autotuner/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CompileArgs:
"tl.disable_safe_memory_legalize": bool, default: False
"""

out_idx: Union[List[int], int] = -1
out_idx: Optional[Union[List[int], int]] = None
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
target_host: Union[str, Target] = None
Expand All @@ -67,8 +67,6 @@ def compile_program(self, program: PrimFunc):

def __hash__(self):
data = {
"out_idx":
self.out_idx,
"execution_backend":
self.execution_backend,
"target":
Expand Down Expand Up @@ -208,7 +206,7 @@ def _load_kernel_from_disk(
cache_path: Path,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
out_idx: Optional[Union[List[int], int]] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
Expand Down
14 changes: 13 additions & 1 deletion tilelang/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def compile(

class _JitImplementation:

out_idx: Any
out_idx: Optional[Union[List[int], int]]
target: Union[str, Target]
target_host: Union[str, Target]
execution_backend: Literal["dlpack", "ctypes", "cython"]
Expand Down Expand Up @@ -168,6 +168,18 @@ def __call__(
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
# Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {})
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
if return_compile_arguments:
compile_args = {
'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
'target': self.target,
'target_host': self.target_host,
'verbose': self.verbose,
'pass_configs': self.pass_configs,
}
return compile_args

key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
Expand Down
Loading