diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index a0c9cf07e..169ff8a57 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -414,3 +414,8 @@ def have_fp8(compute_version): conditions.append(major == 8 and minor >= 9) conditions.append(major >= 9) return any(conditions) + + +def get_nvcc_compiler() -> str: + """Get the path to the nvcc compiler""" + return os.path.join(find_cuda_path(), "bin", "nvcc") diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index c0e05f205..6f57f5a0b 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -3,7 +3,7 @@ from typing import Optional from .utils import is_cuda_target, is_hip_target, is_cpu_target from tilelang import tvm as tvm -from tilelang.contrib.nvcc import get_target_compute_version +from tilelang.contrib.nvcc import get_target_compute_version, get_nvcc_compiler from tvm.target import Target import ctypes import os @@ -44,7 +44,7 @@ def compile_lib(self, timeout: float = None): libpath = src.name.replace(".cu", ".so") command = [ - "nvcc", + get_nvcc_compiler(), "-std=c++17", "-w", # Disable all warning messages "-Xcudafe",