@@ -433,6 +433,7 @@ def get_extensions():
         extra_link_args.append("/DEBUG")

     rocm_sparse_marlin_supported = False
+    rocm_tiled_layout_supported = False
     if use_rocm:
         # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
         found_col16 = False
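The "naive search" mentioned in the context lines above is the probe that feeds the rocm_*_supported flags. Below is a minimal sketch of what such a probe could look like; the helper name probe_hipblaslt_symbol, the /opt/rocm default root, and the plain-text grep are illustrative assumptions, not the exact logic in setup.py.

import glob
import os


def probe_hipblaslt_symbol(symbol, rocm_home="/opt/rocm"):
    # Hypothetical helper: walk the ROCm include tree, read every hipBLASLt
    # header found, and report whether the requested symbol (for example
    # "HIPBLASLT_ORDER_COL16" or the VEC_EXT marker) appears in any of them.
    pattern = os.path.join(rocm_home, "include", "**", "hipblaslt*.h")
    for header in glob.glob(pattern, recursive=True):
        try:
            with open(header, encoding="utf-8", errors="ignore") as f:
                if symbol in f.read():
                    return True
        except OSError:
            continue
    return False


# Usage sketch mirroring the found_col16 check referenced in the diff.
found_col16 = probe_hipblaslt_symbol("HIPBLASLT_ORDER_COL16")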
@@ -488,8 +489,11 @@ def get_extensions():
         # Define ROCm source directories
         rocm_source_dirs = [
             os.path.join(extensions_dir, "rocm", "swizzle"),
-            os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
         ]
+        if rocm_tiled_layout_supported:
+            rocm_source_dirs.append(
+                os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout")
+            )
         if rocm_sparse_marlin_supported:
             rocm_source_dirs.extend([os.path.join(extensions_dir, "cuda", "sparse_marlin")])

@@ -512,14 +516,8 @@ def get_extensions():
         sources = [s for s in sources if s not in mxfp8_sources_to_exclude]

     # TOOD: Remove this and use what CUDA has once we fix all the builds.
+    # TODO: Add support for other ROCm GPUs
     if use_rocm:
-        # Add ROCm GPU architecture check
-        gpu_arch = None
-        if torch.cuda.is_available():
-            gpu_arch = torch.cuda.get_device_properties(0).name
-        if gpu_arch and gpu_arch != "gfx942":
-            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print("Currently only gfx942 is supported. Compiling only for gfx942.")
         extra_compile_args["nvcc"].append("--offload-arch=gfx942")
         sources += rocm_sources
     else:
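Taken together, the hunks gate the tensor_core_tiled_layout sources behind rocm_tiled_layout_supported and compile unconditionally for gfx942 rather than warning at configure time. A simplified, self-contained sketch of that control flow follows; the placeholder values for extensions_dir and the flags are assumptions, and only the gating pattern is taken from the patch.

import os

# Placeholder inputs; in setup.py these come from earlier in get_extensions().
extensions_dir = "torchao/csrc"
use_rocm = True
rocm_tiled_layout_supported = False   # set by the hipblaslt header probe
rocm_sparse_marlin_supported = False  # set by the hipblaslt header probe
extra_compile_args = {"nvcc": []}

rocm_source_dirs = [
    os.path.join(extensions_dir, "rocm", "swizzle"),
]
if rocm_tiled_layout_supported:
    rocm_source_dirs.append(
        os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout")
    )
if rocm_sparse_marlin_supported:
    rocm_source_dirs.append(os.path.join(extensions_dir, "cuda", "sparse_marlin"))

if use_rocm:
    # Only gfx942 is targeted for now; see the added TODO about other ROCm GPUs.
    extra_compile_args["nvcc"].append("--offload-arch=gfx942")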