diff --git a/setup.py b/setup.py index e43e8ba7c..6c14aa452 100644 --- a/setup.py +++ b/setup.py @@ -133,7 +133,7 @@ def get_rocm_version(): return LooseVersion("5.0.0") -def get_tilelang_version(with_cuda=True, with_system_info=True) -> str: +def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str: version = find_version(get_path(".", "VERSION")) local_version_parts = [] if with_system_info: @@ -153,6 +153,18 @@ def get_tilelang_version(with_cuda=True, with_system_info=True) -> str: if local_version_parts: version += f"+{'.'.join(local_version_parts)}" + + if with_commit_id: + commit_id = None + try: + commit_id = subprocess.check_output(['git', 'rev-parse', 'HEAD'], + stderr=subprocess.DEVNULL, + encoding='utf-8').strip() + except subprocess.SubprocessError as error: + raise RuntimeError("Failed to get git commit id") from error + if commit_id: + version += f"+{commit_id}" + return version @@ -476,6 +488,18 @@ def run(self): for item in TL_CONFIG_ITEMS: source_dir = os.path.join(ROOT_DIR, item) target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + # if is VERSION file, replace the content with the new version with commit id + if not PYPI_BUILD and item == "VERSION": + version = get_tilelang_version( + with_cuda=False, with_system_info=False, with_commit_id=True) + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + with open(os.path.join(target_dir, item), "w") as f: + print(f"Writing {version} to {os.path.join(target_dir, item)}") + f.write(version) + continue + if os.path.isdir(source_dir): self.mkpath(target_dir) distutils.dir_util.copy_tree(source_dir, target_dir) @@ -492,7 +516,7 @@ class TileLangSdistCommand(sdist): def make_distribution(self): self.distribution.metadata.name = PACKAGE_NAME self.distribution.metadata.version = get_tilelang_version( - with_cuda=False, with_system_info=False) + with_cuda=False, with_system_info=False, with_commit_id=False) super().make_distribution() @@ -575,9 +599,10 @@ def run(self): # Check if CMake is installed and accessible by attempting to run 'cmake --version'. try: subprocess.check_output(["cmake", "--version"]) - except OSError as e: + except OSError as error: # If CMake is not found, raise an error. - raise RuntimeError("CMake must be installed to build the following extensions") from e + raise RuntimeError( + "CMake must be installed to build the following extensions") from error update_submodules() diff --git a/src/transform/flatten_buffer.cc b/src/transform/flatten_buffer.cc index 8a93d4a0d..190b98db8 100644 --- a/src/transform/flatten_buffer.cc +++ b/src/transform/flatten_buffer.cc @@ -281,11 +281,12 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { auto int_bound = analyzer_->const_int_bound(index); DataType dtype = index->dtype; if (dtype.is_int() && dtype.bits() < 64) { - int64_t max_value = int_bound->max_value + 1; + int64_t max_value = int_bound->max_value; int64_t min_value = int_bound->min_value; const int64_t type_max = (1LL << (dtype.bits() - 1)); const int64_t type_min = -(1LL << (dtype.bits() - 1)); - if (max_value >= type_max || min_value < type_min) { + + if (max_value >= (type_max - 1) || min_value < type_min) { Int64Promoter promoter; for (auto &index : flattened_indices) { safe_indices.push_back(promoter(index)); diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 716d0208f..159d3aad9 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -46,18 +46,24 @@ def run_with_timeout(func, timeout, *args, **kwargs): logger.setLevel(logging.DEBUG) logger.propagate = False -formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') - -file_handler = logging.FileHandler('autotuner.log', mode='w') -file_handler.setLevel(logging.DEBUG) -file_handler.setFormatter(formatter) - -console_handler = logging.StreamHandler(sys.stdout) -console_handler.setLevel(logging.INFO) -console_handler.setFormatter(formatter) - -logger.addHandler(file_handler) -logger.addHandler(console_handler) +# Lazy handler initialization flag +_logger_handlers_initialized = False + + +def _init_logger_handlers(): + global _logger_handlers_initialized + if _logger_handlers_initialized: + return + formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') + file_handler = logging.FileHandler('autotuner.log', mode='w') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.addHandler(console_handler) + _logger_handlers_initialized = True @dataclass(frozen=True) @@ -241,6 +247,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): Returns: AutotuneResult: Results of the auto-tuning process. """ + _init_logger_handlers() sig = inspect.signature(self.fn) keys = list(sig.parameters.keys()) bound_args = sig.bind() diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index b24f32eb3..04520c558 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -17,6 +17,7 @@ import logging from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled +from tilelang.version import __version__ KERNEL_PATH = "kernel.cu" WRAPPED_KERNEL_PATH = "wrapped_kernel.cu" @@ -89,6 +90,7 @@ def _generate_key( """ func_binary = cloudpickle.dumps(func.script()) key_data = { + "version": __version__, "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]), "args_repr": tuple( @@ -149,6 +151,8 @@ def cached( with self._lock: # First check in-memory cache if key in self._memory_cache: + self.logger.warning("Found kernel in memory cache. For better performance," \ + " consider using `@tilelang.jit` instead of direct kernel caching.") return self._memory_cache[key] # Then check disk cache diff --git a/tilelang/version.py b/tilelang/version.py index 4d1c751c5..651de2c3d 100644 --- a/tilelang/version.py +++ b/tilelang/version.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. import os +import subprocess +from typing import Union # Get the absolute path of the current Python script's directory current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -25,5 +27,24 @@ with open(version_file_path, "r") as version_file: __version__ = version_file.read().strip() + +def get_git_commit_id() -> Union[str, None]: + """Get the current git commit hash. + + Returns: + str | None: The git commit hash if available, None otherwise. + """ + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD'], + stderr=subprocess.DEVNULL, + encoding='utf-8').strip() + except subprocess.SubprocessError: + return None + + +# Append git commit hash to version if not already present +if "+" not in __version__ and (commit_id := get_git_commit_id()): + __version__ = f"{__version__}+{commit_id}" + # Define the public API for the module __all__ = ["__version__"]