Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_rocm_version():
return LooseVersion("5.0.0")


def get_tilelang_version(with_cuda=True, with_system_info=True) -> str:
def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=False) -> str:
version = find_version(get_path(".", "VERSION"))
local_version_parts = []
if with_system_info:
Expand All @@ -153,6 +153,18 @@ def get_tilelang_version(with_cuda=True, with_system_info=True) -> str:

if local_version_parts:
version += f"+{'.'.join(local_version_parts)}"

if with_commit_id:
commit_id = None
try:
commit_id = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
stderr=subprocess.DEVNULL,
encoding='utf-8').strip()
except subprocess.SubprocessError as error:
raise RuntimeError("Failed to get git commit id") from error
if commit_id:
version += f"+{commit_id}"

return version


Expand Down Expand Up @@ -476,6 +488,18 @@ def run(self):
for item in TL_CONFIG_ITEMS:
source_dir = os.path.join(ROOT_DIR, item)
target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item)
# if is VERSION file, replace the content with the new version with commit id
if not PYPI_BUILD and item == "VERSION":
version = get_tilelang_version(
with_cuda=False, with_system_info=False, with_commit_id=True)
target_dir = os.path.dirname(target_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
with open(os.path.join(target_dir, item), "w") as f:
print(f"Writing {version} to {os.path.join(target_dir, item)}")
f.write(version)
continue

if os.path.isdir(source_dir):
self.mkpath(target_dir)
distutils.dir_util.copy_tree(source_dir, target_dir)
Expand All @@ -492,7 +516,7 @@ class TileLangSdistCommand(sdist):
def make_distribution(self):
self.distribution.metadata.name = PACKAGE_NAME
self.distribution.metadata.version = get_tilelang_version(
with_cuda=False, with_system_info=False)
with_cuda=False, with_system_info=False, with_commit_id=False)
super().make_distribution()


Expand Down Expand Up @@ -575,9 +599,10 @@ def run(self):
# Check if CMake is installed and accessible by attempting to run 'cmake --version'.
try:
subprocess.check_output(["cmake", "--version"])
except OSError as e:
except OSError as error:
# If CMake is not found, raise an error.
raise RuntimeError("CMake must be installed to build the following extensions") from e
raise RuntimeError(
"CMake must be installed to build the following extensions") from error

update_submodules()

Expand Down
5 changes: 3 additions & 2 deletions src/transform/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
auto int_bound = analyzer_->const_int_bound(index);
DataType dtype = index->dtype;
if (dtype.is_int() && dtype.bits() < 64) {
int64_t max_value = int_bound->max_value + 1;
int64_t max_value = int_bound->max_value;
int64_t min_value = int_bound->min_value;
const int64_t type_max = (1LL << (dtype.bits() - 1));
const int64_t type_min = -(1LL << (dtype.bits() - 1));
if (max_value >= type_max || min_value < type_min) {

if (max_value >= (type_max - 1) || min_value < type_min) {
Int64Promoter promoter;
for (auto &index : flattened_indices) {
safe_indices.push_back(promoter(index));
Expand Down
31 changes: 19 additions & 12 deletions tilelang/autotuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,24 @@ def run_with_timeout(func, timeout, *args, **kwargs):
logger.setLevel(logging.DEBUG)
logger.propagate = False

formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')

file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Lazy handler initialization flag
_logger_handlers_initialized = False


def _init_logger_handlers():
global _logger_handlers_initialized
if _logger_handlers_initialized:
return
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
_logger_handlers_initialized = True


@dataclass(frozen=True)
Expand Down Expand Up @@ -241,6 +247,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
_init_logger_handlers()
sig = inspect.signature(self.fn)
keys = list(sig.parameters.keys())
bound_args = sig.bind()
Expand Down
4 changes: 4 additions & 0 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging

from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.version import __version__

KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
Expand Down Expand Up @@ -89,6 +90,7 @@ def _generate_key(
"""
func_binary = cloudpickle.dumps(func.script())
key_data = {
"version": __version__,
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple(
Expand Down Expand Up @@ -149,6 +151,8 @@ def cached(
with self._lock:
# First check in-memory cache
if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.jit` instead of direct kernel caching.")
return self._memory_cache[key]

# Then check disk cache
Expand Down
21 changes: 21 additions & 0 deletions tilelang/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Licensed under the MIT License.

import os
import subprocess
from typing import Union

# Get the absolute path of the current Python script's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -25,5 +27,24 @@
with open(version_file_path, "r") as version_file:
__version__ = version_file.read().strip()


def get_git_commit_id() -> Union[str, None]:
"""Get the current git commit hash.

Returns:
str | None: The git commit hash if available, None otherwise.
"""
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'],
stderr=subprocess.DEVNULL,
encoding='utf-8').strip()
except subprocess.SubprocessError:
return None


# Append git commit hash to version if not already present
if "+" not in __version__ and (commit_id := get_git_commit_id()):
__version__ = f"{__version__}+{commit_id}"

# Define the public API for the module
__all__ = ["__version__"]
Loading