From 77c752dc78902b3d6ea8d1a7a93b036a534287e7 Mon Sep 17 00:00:00 2001 From: shenggan Date: Wed, 5 Oct 2022 10:49:57 +0800 Subject: [PATCH] [RUNTIME] remove fixed cu_include_dir (#739) Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731 --- python/triton/compiler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/triton/compiler.py b/python/triton/compiler.py index d14c7698100a..5bdeb4347109 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1107,6 +1107,12 @@ def libcuda_dirs(): return [os.path.dirname(loc) for loc in locs] +@functools.lru_cache() +def cuda_home_dirs(): + default_dir = "/usr/local/cuda" + return os.getenv("CUDA_HOME", default=default_dir) + + @contextlib.contextmanager def quiet(): old_stdout, old_stderr = sys.stdout, sys.stderr @@ -1119,7 +1125,7 @@ def quiet(): def _build(name, src, srcdir): cuda_lib_dirs = libcuda_dirs() - cu_include_dir = "/usr/local/cuda/include" + cu_include_dir = os.path.join(cuda_home_dirs(), "include") suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible