diff --git a/testing/python/components/test_tilelang_env.py b/testing/python/components/test_tilelang_env.py
new file mode 100644
index 000000000..9bc767943
--- /dev/null
+++ b/testing/python/components/test_tilelang_env.py
@@ -0,0 +1,17 @@
+import tilelang
+import os
+
+
+def test_env_var():
+    # default value is returned when the variable is unset
+    assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
+    # reads are live: a change to os.environ is visible immediately
+    os.environ["TILELANG_PRINT_ON_COMPILATION"] = "0"
+    assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0"
+    # a runtime override via attribute assignment (EnvVar.__set__) wins over os.environ
+    tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1"
+    assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
+
+
+if __name__ == "__main__":
+    test_env_var()
diff --git a/testing/python/pass_config/test_tilelang_pass_config_disable_warp_specialized.py b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
similarity index 100%
rename from testing/python/pass_config/test_tilelang_pass_config_disable_warp_specialized.py
rename to testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py
diff --git a/tilelang/__init__.py b/tilelang/__init__.py
index 0c0146bdc..2720e3488 100644
--- a/tilelang/__init__.py
+++ b/tilelang/__init__.py
@@ -53,8 +53,8 @@ def _init_logger():
 logger = logging.getLogger(__name__)
 
-from .env import SKIP_LOADING_TILELANG_SO
 from .env import enable_cache, disable_cache, is_cache_enabled  # noqa: F401
+from .env import env as env  # noqa: F401
 
 import tvm
 import tvm.base
@@ -76,12 +76,12 @@ def _load_tile_lang_lib():
 
 # only load once here
-if SKIP_LOADING_TILELANG_SO == "0":
+if env.SKIP_LOADING_TILELANG_SO == "0":
     _LIB, _LIB_PATH = _load_tile_lang_lib()
 
 from .jit import jit, JITKernel, compile  # noqa: F401
 from .profiler import Profiler  # noqa: F401
-from .cache import cached  # noqa: F401
+from .cache import clear_cache  # noqa: F401
 
 from .utils import (
     TensorSupplyType,  # noqa: F401
diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py
index 008807a79..2ed38c58c 100644
--- a/tilelang/autotuner/tuner.py
+++ b/tilelang/autotuner/tuner.py
@@ -25,13 +25,7 @@ import traceback
 from pathlib import Path
 
-from tilelang.env import (
-    TILELANG_CACHE_DIR,
-    TILELANG_AUTO_TUNING_CPU_UTILITIES,
-    TILELANG_AUTO_TUNING_CPU_COUNTS,
-    TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
-    is_cache_enabled,
-)
+from tilelang import env
 from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
 from tilelang.autotuner.capture import get_autotune_inputs
 from tilelang.jit.param import _P, _RProg
@@ -111,7 +105,7 @@ class AutoTuner:
     _kernel_parameters: Optional[Tuple[str, ...]] = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
-    cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
+    cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
 
     def __init__(self, fn: Callable, configs):
         self.fn = fn
@@ -285,7 +279,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
         key = self.generate_cache_key(parameters)
 
         with self._lock:
-            if is_cache_enabled():
+            if env.is_cache_enabled():
                 # First check in-memory cache
                 if key in self._memory_cache:
                     logger.warning("Found kernel in memory cache. 
For better performance," \ @@ -437,9 +431,9 @@ def shape_equal(a, b): return autotuner_result # get the cpu count available_cpu_count = get_available_cpu_count() - cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES) - cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS) - max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT) + cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) + cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS) + max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT) if cpu_counts > 0: num_workers = min(cpu_counts, available_cpu_count) logger.info( @@ -543,7 +537,7 @@ def device_wrapper(func, device, **config_arg): logger.warning("DLPack backend does not support cache saving to disk.") else: with self._lock: - if is_cache_enabled(): + if env.is_cache_enabled(): self._save_result_to_disk(key, autotuner_result) self._memory_cache[key] = autotuner_result diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index 43d9a2202..2a81d88b6 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -4,8 +4,8 @@ from tvm.target import Target from tvm.tir import PrimFunc from tilelang.jit import JITKernel +from tilelang import env from .kernel_cache import KernelCache -from tilelang.env import TILELANG_CLEAR_CACHE # Create singleton instance of KernelCache _kernel_cache_instance = KernelCache() @@ -44,5 +44,5 @@ def clear_cache(): _kernel_cache_instance.clear_cache() -if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): +if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): clear_cache() diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 02b1e0086..caf201f4a 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -14,7 +14,7 @@ from tvm.tir import PrimFunc from tilelang.engine.param import KernelParam -from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled +from tilelang import env from tilelang.jit import JITKernel from tilelang.version import __version__ @@ -61,8 +61,8 @@ def __new__(cls): @staticmethod def _create_dirs(): - os.makedirs(TILELANG_CACHE_DIR, exist_ok=True) - os.makedirs(TILELANG_TMP_DIR, exist_ok=True) + os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True) + os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) def _generate_key( self, @@ -132,7 +132,7 @@ def cached( Returns: JITKernel: The compiled kernel, either freshly compiled or from cache """ - if not is_cache_enabled(): + if not env.is_cache_enabled(): return JITKernel( func, out_idx=out_idx, @@ -190,7 +190,7 @@ def cached( self.logger.warning("DLPack backend does not support cache saving to disk.") else: with self._lock: - if is_cache_enabled(): + if env.is_cache_enabled(): self._save_kernel_to_disk(key, kernel, func, verbose) # Store in memory cache after compilation @@ -215,7 +215,7 @@ def _get_cache_path(self, key: str) -> str: Returns: str: Absolute path to the cache directory for this kernel. 
""" - return os.path.join(TILELANG_CACHE_DIR, key) + return os.path.join(env.TILELANG_CACHE_DIR, key) @staticmethod def _load_binary(path: str): @@ -226,7 +226,7 @@ def _load_binary(path: str): @staticmethod def _safe_write_file(path: str, mode: str, operation: Callable): # Random a temporary file within the same FS as the cache directory - temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}") + temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}") with open(temp_path, mode) as temp_file: operation(temp_file) @@ -396,7 +396,7 @@ def _clear_disk_cache(self): """ try: # Delete the entire cache directory - shutil.rmtree(TILELANG_CACHE_DIR) + shutil.rmtree(env.TILELANG_CACHE_DIR) # Re-create the cache directory KernelCache._create_dirs() diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 5cfe90ced..c0ee6b685 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -6,7 +6,7 @@ import os import subprocess import warnings -from ..env import CUDA_HOME +from tilelang.env import CUDA_HOME import tvm.ffi from tvm.target import Target diff --git a/tilelang/env.py b/tilelang/env.py index adc8860e9..07d707e15 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -4,9 +4,21 @@ import logging import shutil import glob +from dataclasses import dataclass +from typing import Optional logger = logging.getLogger(__name__) +# SETUP ENVIRONMENT VARIABLES +CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") +", which may lead to compilation bugs when utilize tilelang backend." +COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = ( + "Composable Kernel is not installed or found in the expected path") +", which may lead to compilation bugs when utilize tilelang backend." +TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") +", which may lead to compilation bugs when utilize tilelang backend." +TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path") + def _find_cuda_home() -> str: """Find the CUDA install path. 
@@ -46,76 +58,200 @@ def _find_rocm_home() -> str:
     return rocm_home if rocm_home is not None else ""
 
 
-def _initialize_torch_cuda_arch_flags():
-    import os
-    from tilelang.contrib import nvcc
-    from tilelang.utils.target import determine_target
-
-    target = determine_target(return_object=True)
-    # create tmp source file for torch cpp extension
-    compute_version = nvcc.get_target_compute_version(target)
-    major, minor = nvcc.parse_compute_version(compute_version)
-
-    # set TORCH_CUDA_ARCH_LIST
-    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
-
-
-CUDA_HOME = _find_cuda_home()
-ROCM_HOME = _find_rocm_home()
-
-CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
-COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None)
-TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
-TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
-TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
-TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
-
-TILELANG_CACHE_DIR: str = os.environ.get("TILELANG_CACHE_DIR",
-                                         os.path.expanduser("~/.tilelang/cache"))
-TILELANG_TMP_DIR: str = os.path.join(TILELANG_CACHE_DIR, "tmp")
+# Cache control
+class CacheState:
+    """Class to manage global kernel caching state."""
+    _enabled = True
 
-# Print the kernel name on every compilation
-TILELANG_PRINT_ON_COMPILATION: str = os.environ.get("TILELANG_PRINT_COMPILATION", "0")
+    @classmethod
+    def enable(cls):
+        """Enable kernel caching globally."""
+        cls._enabled = True
 
-# Auto-clear cache if environment variable is set
-TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0")
+    @classmethod
+    def disable(cls):
+        """Disable kernel caching globally."""
+        cls._enabled = False
 
-# CPU Utilizations for Auto-Tuning, default is 0.9
-TILELANG_AUTO_TUNING_CPU_UTILITIES: str = os.environ.get("TILELANG_AUTO_TUNING_CPU_UTILITIES",
-                                                         "0.9")
+    @classmethod
+    def is_enabled(cls) -> bool:
+        """Return current cache state."""
+        return cls._enabled
 
-# CPU COUNTS for Auto-Tuning, default is -1,
-# which will use TILELANG_AUTO_TUNING_CPU_UTILITIES * get_available_cpu_count()
-TILELANG_AUTO_TUNING_CPU_COUNTS: str = os.environ.get("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1")
 
-# Max CPU Count for Auto-Tuning, default is 100
-TILELANG_AUTO_TUNING_MAX_CPU_COUNT: str = os.environ.get("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1")
+@dataclass
+class EnvVar:
+    """
+    Descriptor for managing access to a single environment variable.
+
+    Purpose
+    -------
+    In many projects, access to environment variables is scattered across the codebase:
+    * `os.environ.get(...)` calls are repeated everywhere
+    * Default values are hard-coded in multiple places
+    * Overriding env vars for tests/debugging is messy
+    * There's no central place to see all environment variables a package uses
+
+    This descriptor solves those issues by:
+    1. Centralizing the definition of the variable's **key** and **default value**
+    2. Allowing *dynamic* reads from `os.environ` so changes take effect immediately
+    3. Supporting **forced overrides** at runtime (for unit tests or debugging)
+    4. Making forced overrides explicit and easy to audit (helps detect unexpected overrides)
+    5. 
Optionally syncing forced values back to `os.environ` if global consistency is desired
+
+    How it works
+    ------------
+    - This is a `dataclass` implementing the descriptor protocol (`__get__`, `__set__`)
+    - When used as a class attribute, `instance.attr` triggers `__get__()`
+      → returns either the forced override or the live value from `os.environ`
+    - Assigning to the attribute (`instance.attr = value`) triggers `__set__()`
+      → stores `_forced_value` for future reads
+    - You may uncomment the `os.environ[...] = value` line in `__set__` if you want
+      the override to persist globally in the process
+
+    Example
+    -------
+    ```python
+    class Environment:
+        TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "0")
+
+    env = Environment()
+    print(env.TILELANG_PRINT_ON_COMPILATION)  # Reads from os.environ (with default fallback)
+    env.TILELANG_PRINT_ON_COMPILATION = "1"  # Forces value to "1" until changed/reset
+    ```
+
+    Benefits
+    --------
+    * Centralizes all env-var keys and defaults in one place
+    * Live, up-to-date reads (no stale values after `import`)
+    * Testing convenience (override without touching the real env)
+    * Improves IDE discoverability and type hints
+    * Avoids hardcoding `os.environ.get(...)` in multiple places
+    """
 
-# SETUP ENVIRONMENT VARIABLES
-CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
-", which may lead to compilation bugs when utilize tilelang backend."
-COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
-    "Composable Kernel is not installed or found in the expected path")
-", which may lead to compilation bugs when utilize tilelang backend."
-TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
-", which may lead to compilation bugs when utilize tilelang backend."
-TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
+    key: str  # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
+    default: Optional[str]  # Default value if the environment variable is not set
+    _forced_value: Optional[str] = None  # Temporary runtime override (mainly for tests/debugging)
+
+    def get(self):
+        if self._forced_value is not None:
+            return self._forced_value
+        return os.environ.get(self.key, self.default)
+
+    def __get__(self, instance, owner):
+        """
+        Called when the attribute is accessed.
+        1. If a forced value is set, return it
+        2. Otherwise, look up the value in os.environ; return the default if missing
+        """
+        return self.get()
+
+    def __set__(self, instance, value):
+        """
+        Called when the attribute is assigned to.
+        Stores the value as a runtime override (forced value).
+        Optionally, you can also sync this into os.environ for global effect.
+        """
+        self._forced_value = value
+        # Uncomment the following line if you want the override to persist globally:
+        # os.environ[self.key] = value
+
+
+# Cache control API (wrap CacheState)
+enable_cache = CacheState.enable
+disable_cache = CacheState.disable
+is_cache_enabled = CacheState.is_enabled
 
-SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
 
-# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
-TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
+# Central registry of all TileLang configuration, backed by the EnvVar
+# descriptors and CacheState defined above
+class Environment:
+    """
+    Environment configuration for TileLang.
+    Handles CUDA/ROCm detection, integration paths, template/cache locations,
+    auto-tuning configs, and build options.
+    """
 
-if TVM_IMPORT_PYTHON_PATH is not None:
-    os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
-    sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
+    # CUDA/ROCm home directories
+    CUDA_HOME = _find_cuda_home()
+    ROCM_HOME = _find_rocm_home()
+
+    # Path to the TileLang package root
+    TILELANG_PACKAGE_PATH = pathlib.Path(__file__).resolve().parent
+
+    # External library include paths
+    CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None)
+    COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None)
+
+    # TVM integration
+    TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
+    TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None)
+
+    # TileLang resources
+    TILELANG_TEMPLATE_PATH = EnvVar("TL_TEMPLATE_PATH", None)
+    TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", os.path.expanduser("~/.tilelang/cache"))
+    TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))
+
+    # Kernel build options
+    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
+                                           "1")  # print kernel name on compile
+    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # clear cache automatically if set
+
+    # Auto-tuning settings
+    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
+                                                "0.9")  # fraction of available CPUs to use
+    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
+                                             "-1")  # -1 means derive from CPU utilization
+    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
+                                                "-1")  # -1 means no limit
+
+    # Shared-library loading and TVM import path
+    SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
+    TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
+
+    def _initialize_torch_cuda_arch_flags(self) -> None:
+        """
+        Detect target CUDA architecture and set TORCH_CUDA_ARCH_LIST
+        to ensure PyTorch extensions are built for the proper GPU arch.
+        """
+        from tilelang.contrib import nvcc
+        from tilelang.utils.target import determine_target
+
+        target = determine_target(return_object=True)  # get target GPU
+        compute_version = nvcc.get_target_compute_version(target)  # e.g. "8.6"
+        major, minor = nvcc.parse_compute_version(compute_version)  # split to (8, 6)
+        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"  # set env var for PyTorch
+
+    # Cache control API (wrap CacheState)
+    def is_cache_enabled(self) -> bool:
+        return CacheState.is_enabled()
+
+    def enable_cache(self) -> None:
+        CacheState.enable()
+
+    def disable_cache(self) -> None:
+        CacheState.disable()
+
+
+# Instantiate as a global configuration object
+env = Environment()
+
+# Export CUDA_HOME and ROCM_HOME; both are static values
+# resolved once at import time.
+CUDA_HOME = env.CUDA_HOME +ROCM_HOME = env.ROCM_HOME + +# Initialize TVM paths +if env.TVM_IMPORT_PYTHON_PATH is not None: + os.environ["PYTHONPATH"] = env.TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, env.TVM_IMPORT_PYTHON_PATH) else: install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = ( install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, install_tvm_path + "/python") - TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python" + env.TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python" develop_tvm_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") @@ -123,7 +259,7 @@ def _initialize_torch_cuda_arch_flags(): os.environ["PYTHONPATH"] = ( develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, develop_tvm_path + "/python") - TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python" + env.TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python" develop_tvm_library_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm") @@ -136,14 +272,15 @@ def _initialize_torch_cuda_arch_flags(): else: logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE) # pip install build library path - lib_path = os.path.join(TILELANG_PACKAGE_PATH, "lib") + lib_path = os.path.join(env.TILELANG_PACKAGE_PATH, "lib") existing_path = os.environ.get("TVM_LIBRARY_PATH") if existing_path: os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}" else: os.environ["TVM_LIBRARY_PATH"] = lib_path - TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None) + env.TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None) +# Initialize CUTLASS paths if os.environ.get("TL_CUTLASS_PATH", None) is None: install_cutlass_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") @@ -151,13 +288,14 @@ def _initialize_torch_cuda_arch_flags(): os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") if os.path.exists(install_cutlass_path): os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" - CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include" + env.CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include" elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path): os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" - CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include" + env.CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include" else: logger.warning(CUTLASS_NOT_FOUND_MESSAGE) +# Initialize COMPOSABLE_KERNEL paths if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: install_ck_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel") @@ -165,63 +303,27 @@ def _initialize_torch_cuda_arch_flags(): os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel") if os.path.exists(install_ck_path): os.environ["TL_COMPOSABLE_KERNEL_PATH"] = install_ck_path + "/include" - COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include" + env.COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include" elif (os.path.exists(develop_ck_path) and develop_ck_path not in sys.path): os.environ["TL_COMPOSABLE_KERNEL_PATH"] = develop_ck_path + "/include" - COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include" + env.COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include" else: 
logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE) +# Initialize TL_TEMPLATE_PATH if os.environ.get("TL_TEMPLATE_PATH", None) is None: install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src") develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src") if os.path.exists(install_tl_template_path): os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path - TILELANG_TEMPLATE_PATH = install_tl_template_path + env.TILELANG_TEMPLATE_PATH = install_tl_template_path elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path): os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path - TILELANG_TEMPLATE_PATH = develop_tl_template_path + env.TILELANG_TEMPLATE_PATH = develop_tl_template_path else: logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) - -# Cache control -class CacheState: - """Class to manage global kernel caching state.""" - _enabled = True - - @classmethod - def enable(cls): - """Enable kernel caching globally.""" - cls._enabled = True - - @classmethod - def disable(cls): - """Disable kernel caching globally.""" - cls._enabled = False - - @classmethod - def is_enabled(cls) -> bool: - """Return current cache state.""" - return cls._enabled - - -# Replace the old functions with class methods -enable_cache = CacheState.enable -disable_cache = CacheState.disable -is_cache_enabled = CacheState.is_enabled - -__all__ = [ - "CUTLASS_INCLUDE_DIR", - "COMPOSABLE_KERNEL_INCLUDE_DIR", - "TVM_PYTHON_PATH", - "TVM_LIBRARY_PATH", - "TILELANG_TEMPLATE_PATH", - "CUDA_HOME", - "ROCM_HOME", - "TILELANG_CACHE_DIR", - "enable_cache", - "disable_cache", - "is_cache_enabled", - "_initialize_torch_cuda_arch_flags", -] +# Export static variables after initialization. +CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR +COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR +TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 3a2de02ef..15cb47b62 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -4,9 +4,9 @@ from tvm.tir import PrimFunc import tilelang -from tilelang import tvm as tvm +from tilelang import tvm +from tilelang import env from tilelang.engine.param import CompiledArtifact, KernelParam -from tilelang.env import TILELANG_PRINT_ON_COMPILATION from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, NVRTCKernelAdapter, TorchDLPackKernelAdapter) from tilelang.profiler import Profiler, TensorSupplyType @@ -114,7 +114,7 @@ def __init__( # Print log on compilation starts # NOTE(Chenggang): printing could let the training/inference framework easier to know # whether the communication timeout is from compilation - if TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"): + if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"): print(f"TileLang begins to compile kernel `{func.__name__}` with `{out_idx=}`") # Compile the TileLang function and create a kernel adapter for execution. 
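[Reviewer note] The `EnvVar` descriptor plus the `Environment` singleton are the core of this refactor. Below is a minimal, self-contained sketch (not part of the patch) of the intended semantics, mirroring the new `test_tilelang_env.py`; the `TILELANG_DEMO` key and `Env` class are hypothetical stand-ins:

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass
class EnvVar:
    key: str
    default: Optional[str]
    _forced_value: Optional[str] = None

    def __get__(self, instance, owner):
        # Live read: os.environ is consulted on every access,
        # unless a runtime override has been forced.
        if self._forced_value is not None:
            return self._forced_value
        return os.environ.get(self.key, self.default)

    def __set__(self, instance, value):
        # Runtime override; not written back to os.environ by default.
        self._forced_value = value


class Env:
    DEMO = EnvVar("TILELANG_DEMO", "1")  # hypothetical key, for illustration only


env = Env()
assert env.DEMO == "1"              # default when unset
os.environ["TILELANG_DEMO"] = "0"
assert env.DEMO == "0"              # change to os.environ is visible immediately
env.DEMO = "1"                      # forced override wins over os.environ
assert env.DEMO == "1"
```

One consequence visible in the test: the override lives on the class-level descriptor, so it is shared by every `Environment` instance and persists until reassigned.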
diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index 8cc768467..cc7975ae8 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -2,12 +2,12 @@ import torch import warnings from torch.utils.cpp_extension import load, _import_module_from_library -from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR +from tilelang import env # Define paths -compress_util = os.path.join(TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu") +compress_util = os.path.join(env.TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu") # Cache directory for compiled extensions -_CACHE_DIR = os.path.join(TILELANG_CACHE_DIR, "sparse_compressor") +_CACHE_DIR = os.path.join(env.TILELANG_CACHE_DIR, "sparse_compressor") os.makedirs(_CACHE_DIR, exist_ok=True) @@ -22,9 +22,8 @@ def _get_cached_lib(): # If loading fails, recompile pass - from tilelang.env import _initialize_torch_cuda_arch_flags # Set TORCH_CUDA_ARCH_LIST - _initialize_torch_cuda_arch_flags() + env._initialize_torch_cuda_arch_flags() # Compile if not cached or loading failed return load( @@ -34,8 +33,8 @@ def _get_cached_lib(): '-O2', '-std=c++17', '-lineinfo', - f'-I{CUTLASS_INCLUDE_DIR}', - f'-I{CUTLASS_INCLUDE_DIR}/../tools/util/include', + f'-I{env.CUTLASS_INCLUDE_DIR}', + f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include', '-arch=sm_90', ], build_directory=_CACHE_DIR,
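[Reviewer note] After this refactor the cache toggles are reachable both as module-level aliases (re-exported in `tilelang/__init__.py`) and as methods on the global `env` object; both drive the same `CacheState`. A short usage sketch, assuming an importable tilelang build:

```python
import tilelang

tilelang.disable_cache()                    # module-level alias
assert not tilelang.env.is_cache_enabled()  # env method sees the same CacheState

tilelang.env.enable_cache()                 # method on the global env object
assert tilelang.is_cache_enabled()
```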