Skip to content

Commit 9b3d135

Browse files
authored
[Refactor] Skip patchelf if not installed (#477)
* [Refactor] Enhance TMA barrier validation and support for additional architectures * Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`. * Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic. * Enhance logging in setup.py and refactor TMA architecture checks in phase.py * Added logging configuration to setup.py, replacing print statements with logger for better traceability. * Updated download and extraction functions to use logger for status messages. * Refactored TMA architecture checks in phase.py to utilize the new `have_tma` function for improved clarity and maintainability. * Introduced support for additional compute capabilities in nvcc.py, including TMA support checks. * Update documentation for get_target_compute_version to reflect correct GPU compute capability range * Refactor have_tma function to accept tvm.target.Target instead of compute_version * Updated the `have_tma` function in nvcc.py to take a `target` parameter, improving clarity and usability. * Adjusted calls to `have_tma` in phase.py to pass the target directly, enhancing maintainability and consistency in TMA support checks.
1 parent 912ead3 commit 9b3d135

File tree

3 files changed

+63
-23
lines changed

3 files changed

+63
-23
lines changed

setup.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@
2121
import multiprocessing
2222
from setuptools.command.build_ext import build_ext
2323
import importlib
24+
import logging
25+
26+
# Configure logging with basic settings
27+
logging.basicConfig(
28+
level=logging.INFO,
29+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
30+
datefmt='%Y-%m-%d %H:%M:%S')
31+
32+
logger = logging.getLogger(__name__)
2433

2534
# Environment variables False/True
2635
PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true"
@@ -195,7 +204,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
195204
download_url = f"{base_url}/{file_name}"
196205

197206
# Download the file
198-
print(f"Downloading {file_name} from {download_url}")
207+
logger.info(f"Downloading {file_name} from {download_url}")
199208
with urllib.request.urlopen(download_url) as response:
200209
if response.status != 200:
201210
raise Exception(f"Download failed with status code {response.status}")
@@ -208,11 +217,11 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
208217
os.remove(os.path.join(extract_path, file_name))
209218

210219
# Extract the file
211-
print(f"Extracting {file_name} to {extract_path}")
220+
logger.info(f"Extracting {file_name} to {extract_path}")
212221
with tarfile.open(fileobj=BytesIO(file_content), mode="r:xz") as tar:
213222
tar.extractall(path=extract_path)
214223

215-
print("Download and extraction completed successfully.")
224+
logger.info("Download and extraction completed successfully.")
216225
return os.path.abspath(os.path.join(extract_path, file_name.replace(".tar.xz", "")))
217226

218227

@@ -238,7 +247,7 @@ def is_git_repo():
238247
return False
239248

240249
if not is_git_repo():
241-
print("Info: Not a git repository, skipping submodule update.")
250+
logger.info("Info: Not a git repository, skipping submodule update.")
242251
return
243252

244253
try:
@@ -288,7 +297,15 @@ def patch_libs(libpath):
288297
and have a hard-coded rpath.
289298
Set rpath to the directory of libs so auditwheel works well.
290299
"""
291-
subprocess.run(['patchelf', '--set-rpath', '$ORIGIN', libpath])
300+
# check if patchelf is installed
301+
# find patchelf in the system
302+
patchelf_path = shutil.which("patchelf")
303+
if not patchelf_path:
304+
logger.warning(
305+
"patchelf is not installed, which is required for auditwheel to work for compatible wheels."
306+
)
307+
return
308+
subprocess.run([patchelf_path, '--set-rpath', '$ORIGIN', libpath])
292309

293310

294311
class TileLangBuilPydCommand(build_py):
@@ -302,11 +319,11 @@ def run(self):
302319
ext_modules = build_ext_cmd.extensions
303320
for ext in ext_modules:
304321
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
305-
print(f"Extension {ext.name} output directory: {extdir}")
322+
logger.info(f"Extension {ext.name} output directory: {extdir}")
306323

307324
ext_output_dir = os.path.dirname(extdir)
308-
print(f"Extension output directory (parent): {ext_output_dir}")
309-
print(f"Build temp directory: {build_temp_dir}")
325+
logger.info(f"Extension output directory (parent): {ext_output_dir}")
326+
logger.info(f"Build temp directory: {build_temp_dir}")
310327

311328
# copy cython files
312329
CYTHON_SRC = [
@@ -373,12 +390,12 @@ def run(self):
373390
os.makedirs(target_dir_release, exist_ok=True)
374391
os.makedirs(target_dir_develop, exist_ok=True)
375392
shutil.copy2(source_lib_file, target_dir_release)
376-
print(f"Copied {source_lib_file} to {target_dir_release}")
393+
logger.info(f"Copied {source_lib_file} to {target_dir_release}")
377394
shutil.copy2(source_lib_file, target_dir_develop)
378-
print(f"Copied {source_lib_file} to {target_dir_develop}")
395+
logger.info(f"Copied {source_lib_file} to {target_dir_develop}")
379396
os.remove(source_lib_file)
380397
else:
381-
print(f"WARNING: {item} not found in any expected directories!")
398+
logger.info(f"WARNING: {item} not found in any expected directories!")
382399

383400
TVM_CONFIG_ITEMS = [
384401
f"{build_temp_dir}/config.cmake",
@@ -394,7 +411,7 @@ def run(self):
394411
if os.path.exists(source_dir):
395412
shutil.copy2(source_dir, target_dir)
396413
else:
397-
print(f"INFO: {source_dir} does not exist.")
414+
logger.info(f"INFO: {source_dir} does not exist.")
398415

399416
TVM_PACAKGE_ITEMS = [
400417
"3rdparty/tvm/src",
@@ -489,18 +506,18 @@ class TileLangDevelopCommand(develop):
489506
"""
490507

491508
def run(self):
492-
print("Running TileLangDevelopCommand")
509+
logger.info("Running TileLangDevelopCommand")
493510
# 1. Build the C/C++ extension modules
494511
self.run_command("build_ext")
495512

496513
build_ext_cmd = self.get_finalized_command("build_ext")
497514
ext_modules = build_ext_cmd.extensions
498515
for ext in ext_modules:
499516
extdir = build_ext_cmd.get_ext_fullpath(ext.name)
500-
print(f"Extension {ext.name} output directory: {extdir}")
517+
logger.info(f"Extension {ext.name} output directory: {extdir}")
501518

502519
ext_output_dir = os.path.dirname(extdir)
503-
print(f"Extension output directory (parent): {ext_output_dir}")
520+
logger.info(f"Extension output directory (parent): {ext_output_dir}")
504521

505522
# Copy the built TVM to the package directory
506523
TVM_PREBUILD_ITEMS = [
@@ -524,7 +541,7 @@ def run(self):
524541
# remove the original file
525542
os.remove(source_lib_file)
526543
else:
527-
print(f"INFO: {source_lib_file} does not exist.")
544+
logger.info(f"INFO: {source_lib_file} does not exist.")
528545

529546

530547
class CMakeExtension(Extension):

tilelang/contrib/nvcc.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def get_target_compute_version(target=None):
270270
Returns
271271
-------
272272
compute_version : str
273-
compute capability of a GPU (e.g. "8.6")
273+
compute capability of a GPU (e.g. "8.6" or "9.0")
274274
"""
275275
# 1. input target object
276276
# 2. Target.current()
@@ -279,10 +279,17 @@ def get_target_compute_version(target=None):
279279
arch = target.arch.split("_")[1]
280280
if len(arch) == 2:
281281
major, minor = arch
282+
# Handle old format like sm_89
282283
return major + "." + minor
283284
elif len(arch) == 3:
284-
# This is for arch like "sm_90a"
285-
major, minor, suffix = arch
285+
major = int(arch[0])
286+
if major < 2:
287+
major = arch[0:2]
288+
minor = arch[2]
289+
return major + "." + minor
290+
else:
291+
# This is for arch like "sm_90a"
292+
major, minor, suffix = arch
286293
return major + "." + minor + "." + suffix
287294

288295
# 3. GPU compute version
@@ -416,6 +423,23 @@ def have_fp8(compute_version):
416423
return any(conditions)
417424

418425

426+
@tvm._ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
427+
def have_tma(target):
428+
"""Whether TMA (Tensor Memory Accelerator) support is provided by the specified compilation target or not
429+
430+
Parameters
431+
----------
432+
target : tvm.target.Target
433+
The compilation target
434+
"""
435+
compute_version = get_target_compute_version(target)
436+
major, minor = parse_compute_version(compute_version)
437+
# TMA is supported in Hopper (compute capability 9.0) or later architectures.
438+
conditions = [False]
439+
conditions.append(major >= 9)
440+
return any(conditions)
441+
442+
419443
def get_nvcc_compiler() -> str:
420444
"""Get the path to the nvcc compiler"""
421445
return os.path.join(find_cuda_path(), "bin", "nvcc")

tilelang/engine/phase.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
from tvm.target import Target
55
import tilelang
66
from tilelang.transform import PassContext
7+
from tilelang.contrib.nvcc import have_tma
78
from typing import Optional
89

9-
SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}
10-
1110

1211
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1312
target: Optional[Target] = None) -> bool:
1413
if pass_ctx is None:
1514
pass_ctx = tilelang.transform.get_pass_context()
16-
if target.arch not in SUPPORTED_TMA_ARCHS:
15+
if not have_tma(target):
1716
return False
1817
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
1918
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
@@ -22,7 +21,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
2221

2322

2423
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
25-
return target.arch in SUPPORTED_TMA_ARCHS
24+
return have_tma(target)
2625

2726

2827
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:

0 commit comments

Comments
 (0)