Skip to content

Commit d321a2c

Browse files
authored
Conditional ROCm kernel build (#2839)
* conditional kernel build * lint * Remove GPU architecture check for ROCm in setup.py and add a TODO for supporting other ROCm GPUs.
1 parent 6a6a672 commit d321a2c

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

setup.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def get_extensions():
433433
extra_link_args.append("/DEBUG")
434434

435435
rocm_sparse_marlin_supported = False
436+
rocm_tiled_layout_supported = False
436437
if use_rocm:
437438
# naive search for hipblaslt.h; check whether any copy found contains HIPBLASLT_ORDER_COL16 and VEC_EXT
438439
found_col16 = False
@@ -488,8 +489,11 @@ def get_extensions():
488489
# Define ROCm source directories
489490
rocm_source_dirs = [
490491
os.path.join(extensions_dir, "rocm", "swizzle"),
491-
os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
492492
]
493+
if rocm_tiled_layout_supported:
494+
rocm_source_dirs.append(
495+
os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout")
496+
)
493497
if rocm_sparse_marlin_supported:
494498
rocm_source_dirs.extend([os.path.join(extensions_dir, "cuda", "sparse_marlin")])
495499

@@ -512,14 +516,8 @@ def get_extensions():
512516
sources = [s for s in sources if s not in mxfp8_sources_to_exclude]
513517

514518
# TODO: Remove this and use what CUDA has once we fix all the builds.
519+
# TODO: Add support for other ROCm GPUs
515520
if use_rocm:
516-
# Add ROCm GPU architecture check
517-
gpu_arch = None
518-
if torch.cuda.is_available():
519-
gpu_arch = torch.cuda.get_device_properties(0).name
520-
if gpu_arch and gpu_arch != "gfx942":
521-
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
522-
print("Currently only gfx942 is supported. Compiling only for gfx942.")
523521
extra_compile_args["nvcc"].append("--offload-arch=gfx942")
524522
sources += rocm_sources
525523
else:

0 commit comments

Comments
 (0)