Skip to content

Commit

Permalink
Add version 11.1 in finding CUDA libdevice (apache#7033)
Browse files Browse the repository at this point in the history
* Add CUDA 11.1 libdevice

Maybe we should have a >= check instead.
I also added a fallback to detect the version if version.txt is
missing. Calling nvcc for this has been inspired by what PyTorch
does when compiling extension modules.
  • Loading branch information
t-vi authored and electriclilies committed Feb 18, 2021
1 parent c8122eb commit d2886fc
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,20 @@ def get_cuda_version(cuda_path):
with open(version_file_path) as f:
version_str = f.readline().replace("\n", "").replace("\r", "")
return float(version_str.split(" ")[2][:2])
except:
raise RuntimeError("Cannot read cuda version file")
except FileNotFoundError:
pass

cmd = [os.path.join(cuda_path, "bin", "nvcc"), "--version"]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
out = py_str(out)
if proc.returncode == 0:
release_line = [l for l in out.split("\n") if "release" in l][0]
release_fields = [s.strip() for s in release_line.split(",")]
release_version = [f[1:] for f in release_fields if f.startswith("V")][0]
major_minor = ".".join(release_version.split(".")[:2])
return float(major_minor)
raise RuntimeError("Cannot read cuda version file")


@tvm._ffi.register_func("tvm_callback_libdevice_path")
Expand All @@ -174,7 +186,7 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0):
if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
Expand Down Expand Up @@ -219,6 +231,7 @@ def parse_compute_version(compute_version):
minor = int(split_ver[1])
return major, minor
except (IndexError, ValueError) as err:
# pylint: disable=raise-missing-from
raise RuntimeError("Compute version parsing error: " + str(err))


Expand Down

0 comments on commit d2886fc

Please sign in to comment.