From 4b48250a74495f04d5fff555d708a2ba5f5f5190 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 4 Jun 2025 20:43:53 +0800 Subject: [PATCH 1/4] [Enhancement] Update AutoTuner and JIT compilation arguments * Added functionality to return compile arguments in the JIT implementation, enhancing the autotuner's caching capabilities. * Modified `CompileArgs` and `AutotuneResult` classes to support optional `out_idx` parameter, improving flexibility in compile argument handling. * Refactored the `_AutoTunerImplementation` to utilize the new compile arguments, ensuring better integration and performance during tuning processes. --- tilelang/autotuner/__init__.py | 11 +++++++++++ tilelang/autotuner/param.py | 6 ++---- tilelang/jit/__init__.py | 21 ++++++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 5cead02c3..21e744d84 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -552,6 +552,8 @@ def wrapper(*args, **kwargs): def jit_compile(**config_arg): return fn(*args, **kwargs, __tune_params=config_arg) + compile_arguments = fn(__return_compile_arguments=True) + autotuner = AutoTuner( fn, configs=configs).set_profile_args( supply_type=self.supply_type, @@ -563,13 +565,22 @@ def jit_compile(**config_arg): skip_check=self.skip_check, manual_check_prog=self.manual_check_prog, cache_input_tensors=self.cache_input_tensors, + ).set_compile_args( + out_idx=compile_arguments['out_idx'], + execution_backend=compile_arguments['execution_backend'], + target=compile_arguments['target'], + target_host=compile_arguments['target_host'], + verbose=compile_arguments['verbose'], + pass_configs=compile_arguments['pass_configs'], ) + autotuner.jit_compile = jit_compile autotuner.set_kernel_parameters(key) autotuner.run = partial(autotuner.run, warmup, rep, timeout) artifact = autotuner.run() + self._tuner_cache[key] = artifact.kernel return self._tuner_cache[key] diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index 257c97357..d9895004d 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -49,7 +49,7 @@ class CompileArgs: "tl.disable_safe_memory_legalize": bool, default: False """ - out_idx: Union[List[int], int] = -1 + out_idx: Optional[Union[List[int], int]] = None execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython" target: Literal['auto', 'cuda', 'hip'] = 'auto' target_host: Union[str, Target] = None @@ -67,8 +67,6 @@ def compile_program(self, program: PrimFunc): def __hash__(self): data = { - "out_idx": - self.out_idx, "execution_backend": self.execution_backend, "target": @@ -208,7 +206,7 @@ def _load_kernel_from_disk( cache_path: Path, target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, - out_idx: List[int] = None, + out_idx: Optional[List[int]] = None, execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", pass_configs: dict = None, func: Callable = None, diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index c8483c9ad..01d8f4e96 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -82,7 +82,7 @@ def compile( class _JitImplementation: - out_idx: Any + out_idx: Optional[Any] target: Union[str, Target] target_host: Union[str, Target] execution_backend: Literal["dlpack", "ctypes", "cython"] @@ -168,6 +168,25 @@ def __call__( def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Separate out the tuning parameters from the user's kwargs tune_params = kwargs.pop('__tune_params', {}) + # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache + return_compile_arguments = kwargs.pop('__return_compile_arguments', False) + if return_compile_arguments: + # out_idx: Optional[Union[List[int], int]] = None + # execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython" + # target: Literal['auto', 'cuda', 'hip'] = 'auto' + # target_host: Union[str, Target] = None + # verbose: bool = False + # pass_configs: Optional[Dict[str, Any]] = None + + compile_args = { + 'out_idx': self.out_idx, + 'execution_backend': self.execution_backend, + 'target': self.target, + 'target_host': self.target_host, + 'verbose': self.verbose, + 'pass_configs': self.pass_configs, + } + return compile_args key_args_tuple = args key_kwargs_tuple = tuple(sorted(kwargs.items())) From f756dc2206d8781bbb448db2d62f88514b95a7c0 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 4 Jun 2025 20:47:56 +0800 Subject: [PATCH 2/4] Update tilelang/autotuner/param.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tilelang/autotuner/param.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index d9895004d..178493856 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -206,7 +206,7 @@ def _load_kernel_from_disk( cache_path: Path, target: Union[str, Target] = "auto", target_host: Union[str, Target] = None, - out_idx: Optional[List[int]] = None, + out_idx: Optional[Union[List[int], int]] = None, execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", pass_configs: dict = None, func: Callable = None, From 24a74c43647b8f49ca7132f29a9d3be0dbdaf369 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 4 Jun 2025 20:48:43 +0800 Subject: [PATCH 3/4] remove redundant comments --- tilelang/jit/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 01d8f4e96..2d2121c13 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -171,13 +171,6 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache return_compile_arguments = kwargs.pop('__return_compile_arguments', False) if return_compile_arguments: - # out_idx: Optional[Union[List[int], int]] = None - # execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython" - # target: Literal['auto', 'cuda', 'hip'] = 'auto' - # target_host: Union[str, Target] = None - # verbose: bool = False - # pass_configs: Optional[Dict[str, Any]] = None - compile_args = { 'out_idx': self.out_idx, 'execution_backend': self.execution_backend, From 698df64480bc78ac3e8ce04ab49665546cbaef96 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Wed, 4 Jun 2025 20:50:28 +0800 Subject: [PATCH 4/4] Update tilelang/jit/__init__.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tilelang/jit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 2d2121c13..3b05de45e 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -82,7 +82,7 @@ def compile( class _JitImplementation: - out_idx: Optional[Any] + out_idx: Optional[Union[List[int], int]] target: Union[str, Target] target_host: Union[str, Target] execution_backend: Literal["dlpack", "ctypes", "cython"]