Skip to content

Commit

Permalink
[RUNTIME] remove fixed cu_include_dir (#739)
Browse files Browse the repository at this point in the history
Use environment variable `CUDA_HOME` with default value`/usr/local/cuda` for `cu_include_dir` #731
  • Loading branch information
Shenggan authored Oct 5, 2022
1 parent d3c925d commit 77c752d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/triton/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,12 @@ def libcuda_dirs():
return [os.path.dirname(loc) for loc in locs]


@functools.lru_cache()
def cuda_home_dirs():
default_dir = "/usr/local/cuda"
return os.getenv("CUDA_HOME", default=default_dir)


@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
Expand All @@ -1119,7 +1125,7 @@ def quiet():

def _build(name, src, srcdir):
cuda_lib_dirs = libcuda_dirs()
cu_include_dir = "/usr/local/cuda/include"
cu_include_dir = os.path.join(cuda_home_dirs(), "include")
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
# try to avoid setuptools if possible
Expand Down

0 comments on commit 77c752d

Please sign in to comment.