Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
import multiprocessing
from setuptools.command.build_ext import build_ext
import importlib
import logging

# Configure logging with basic settings
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')

logger = logging.getLogger(__name__)

# Environment variables False/True
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
Expand Down Expand Up @@ -195,7 +204,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
download_url = f"{base_url}/{file_name}"

# Download the file
print(f"Downloading {file_name} from {download_url}")
logger.info(f"Downloading {file_name} from {download_url}")
with urllib.request.urlopen(download_url) as response:
if response.status != 200:
raise Exception(f"Download failed with status code {response.status}")
Expand All @@ -208,11 +217,11 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
os.remove(os.path.join(extract_path, file_name))

# Extract the file
print(f"Extracting {file_name} to {extract_path}")
logger.info(f"Extracting {file_name} to {extract_path}")
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
tar.extractall(path=extract_path)

print("Download and extraction completed successfully.")
logger.info("Download and extraction completed successfully.")
return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))


Expand All @@ -238,7 +247,7 @@ def is_git_repo():
return False

if not is_git_repo():
print("Info: Not a git repository, skipping submodule update.")
logger.info("Info: Not a git repository, skipping submodule update.")
return

try:
Expand Down Expand Up @@ -288,7 +297,15 @@ def patch_libs(libpath):
and have a hard-coded rpath.
Set rpath to the directory of libs so auditwheel works well.
"""
subprocess.run(['patchelf', '--set-rpath', '$ORIGIN', libpath])
# check if patchelf is installed
# find patchelf in the system
patchelf_path = shutil.which("patchelf")
if not patchelf_path:
logger.warning(
"patchelf is not installed, which is required for auditwheel to work for compatible wheels."
)
return
subprocess.run([patchelf_path, '--set-rpath', '$ORIGIN', libpath])


class TileLangBuilPydCommand(build_py):
Expand All @@ -302,11 +319,11 @@ def run(self):
ext_modules = build_ext_cmd.extensions
for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
print(f"Extension {ext.name} output directory: {extdir}")
logger.info(f"Extension {ext.name} output directory: {extdir}")

ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}")
print(f"Build temp directory: {build_temp_dir}")
logger.info(f"Extension output directory (parent): {ext_output_dir}")
logger.info(f"Build temp directory: {build_temp_dir}")

# copy cython files
CYTHON_SRC = [
Expand Down Expand Up @@ -373,12 +390,12 @@ def run(self):
os.makedirs(target_dir_release, exist_ok=True)
os.makedirs(target_dir_develop, exist_ok=True)
shutil.copy2(source_lib_file, target_dir_release)
print(f"Copied {source_lib_file} to {target_dir_release}")
logger.info(f"Copied {source_lib_file} to {target_dir_release}")
shutil.copy2(source_lib_file, target_dir_develop)
print(f"Copied {source_lib_file} to {target_dir_develop}")
logger.info(f"Copied {source_lib_file} to {target_dir_develop}")
os.remove(source_lib_file)
else:
print(f"WARNING: {item} not found in any expected directories!")
logger.info(f"WARNING: {item} not found in any expected directories!")

TVM_CONFIG_ITEMS = [
f"{build_temp_dir}/config.cmake",
Expand All @@ -394,7 +411,7 @@ def run(self):
if os.path.exists(source_dir):
shutil.copy2(source_dir, target_dir)
else:
print(f"INFO: {source_dir} does not exist.")
logger.info(f"INFO: {source_dir} does not exist.")

TVM_PACAKGE_ITEMS = [
"3rdparty/tvm/src",
Expand Down Expand Up @@ -489,18 +506,18 @@ class TileLangDevelopCommand(develop):
"""

def run(self):
print("Running TileLangDevelopCommand")
logger.info("Running TileLangDevelopCommand")
# 1. Build the C/C++ extension modules
self.run_command("build_ext")

build_ext_cmd = self.get_finalized_command("build_ext")
ext_modules = build_ext_cmd.extensions
for ext in ext_modules:
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
print(f"Extension {ext.name} output directory: {extdir}")
logger.info(f"Extension {ext.name} output directory: {extdir}")

ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}")
logger.info(f"Extension output directory (parent): {ext_output_dir}")

# Copy the built TVM to the package directory
TVM_PREBUILD_ITEMS = [
Expand All @@ -524,7 +541,7 @@ def run(self):
# remove the original file
os.remove(source_lib_file)
else:
print(f"INFO: {source_lib_file} does not exist.")
logger.info(f"INFO: {source_lib_file} does not exist.")


class CMakeExtension(Extension):
Expand Down
30 changes: 27 additions & 3 deletions tilelang/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def get_target_compute_version(target=None):
Returns
-------
compute_version : str
compute capability of a GPU (e.g. "8.6")
compute capability of a GPU (e.g. "8.6" or "9.0")
"""
# 1. input target object
# 2. Target.current()
Expand All @@ -279,10 +279,17 @@ def get_target_compute_version(target=None):
arch = target.arch.split("_")[1]
if len(arch) == 2:
major, minor = arch
# Handle old format like sm_89
return major + "." + minor
elif len(arch) == 3:
# This is for arch like "sm_90a"
major, minor, suffix = arch
major = int(arch[0])
if major < 2:
major = arch[0:2]
minor = arch[2]
return major + "." + minor
else:
# This is for arch like "sm_90a"
major, minor, suffix = arch
return major + "." + minor + "." + suffix

# 3. GPU compute version
Expand Down Expand Up @@ -416,6 +423,23 @@ def have_fp8(compute_version):
return any(conditions)


@tvm._ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target):
"""Whether TMA support is provided in the specified compute capability or not

Parameters
----------
target : tvm.target.Target
The compilation target
"""
compute_version = get_target_compute_version(target)
major, minor = parse_compute_version(compute_version)
# TMA is supported in Ada Lovelace (9.0) or later architectures.
conditions = [False]
conditions.append(major >= 9)
return any(conditions)


def get_nvcc_compiler() -> str:
"""Get the path to the nvcc compiler"""
return os.path.join(find_cuda_path(), "bin", "nvcc")
7 changes: 3 additions & 4 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma
from typing import Optional

SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}


def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in SUPPORTED_TMA_ARCHS:
if not have_tma(target):
return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
Expand All @@ -22,7 +21,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,


def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return target.arch in SUPPORTED_TMA_ARCHS
return have_tma(target)


def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
Expand Down