diff --git a/setup.py b/setup.py
index 229e18eec..eb56f863a 100644
--- a/setup.py
+++ b/setup.py
@@ -46,9 +46,11 @@ def read_version(file_path="version.txt"):
     CUDAExtension,
     BuildExtension,
     CUDA_HOME,
+    ROCM_HOME,
     IS_WINDOWS
 )
 
+IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
 
 def get_extensions():
     debug_mode = os.getenv('DEBUG', '0') == '1'
@@ -57,11 +59,11 @@ def get_extensions():
 
     if not torch.cuda.is_available():
         print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
-    if CUDA_HOME is None and torch.cuda.is_available():
-        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+    if (CUDA_HOME is None and not IS_ROCM) and torch.cuda.is_available():
+        print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
         print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
 
-    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
     extension = CUDAExtension if use_cuda else CppExtension
 
     if not IS_WINDOWS:
@@ -71,15 +73,14 @@ def get_extensions():
                 "-O3" if not debug_mode else "-O0",
                 "-fdiagnostics-color=always",
             ],
-            "nvcc": [
-                "-O3" if not debug_mode else "-O0",
-                "-t=0",
-            ]
         }
+        if use_cuda and not IS_ROCM:
+            extra_compile_args["nvcc"] = ["-O3" if not debug_mode else "-O0", "-t=0",]
 
         if debug_mode:
             extra_compile_args["cxx"].append("-g")
-            extra_compile_args["nvcc"].append("-g")
+            if "nvcc" in extra_compile_args:
+                extra_compile_args["nvcc"].append("-g")
             extra_link_args.extend(["-O0", "-g"])
 
         else:
@@ -107,17 +108,35 @@ def get_extensions():
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True))
 
-    if use_cuda:
+    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "fp6_llm")
+    hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))
+
+    if not IS_ROCM and use_cuda:
         sources += cuda_sources
 
-    ext_modules = [
-        extension(
-            "torchao._C",
-            sources,
-            extra_compile_args=extra_compile_args,
-            extra_link_args=extra_link_args,
-        )
-    ]
+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
+        sources += hip_sources
+
+    ## TODO: remove this condition and use what we have in CUDA once we fix the individual builds.
+    if not IS_ROCM:
+        ext_modules = [
+            extension(
+                "torchao._C",
+                sources,
+                extra_compile_args=extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        ]
+    else:
+        ext_modules = [
+            extension(
+                "torchao._C",
+                sources,
+                extra_compile_args=extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        ]
 
     return ext_modules
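
For reference, the backend probes introduced above can be exercised on their own before running a full build. The snippet below is a minimal sketch, not part of this diff: it assumes a local PyTorch install and simply mirrors the IS_ROCM / use_cuda decisions added to setup.py so they can be inspected from a Python shell.

    import torch
    from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

    # ROCm builds of PyTorch report a HIP version string and ship a ROCM_HOME path;
    # CUDA builds leave torch.version.hip set to None.
    IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

    # Build the GPU extension only when a device is visible and either toolkit
    # (CUDA or ROCm) can be located; otherwise fall back to a plain C++ extension.
    use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
    extension = CUDAExtension if use_cuda else CppExtension

    print(f"IS_ROCM={IS_ROCM}, use_cuda={use_cuda}, extension={extension.__name__}")

On a ROCm machine this should report IS_ROCM=True and still select CUDAExtension, since torch's cpp_extension machinery handles hipification of the .cu sources under the same entry point.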