Skip to content

Commit 3ca3e2f

Browse files
hwchen2017gyou2021
authored andcommitted
Update CUDA compute capability to support Blackwell (deepspeedai#7047)
Update CUDA compute capability for cross compile according to wiki page. https://en.wikipedia.org/wiki/CUDA#GPUs_supported --------- Signed-off-by: Hongwei <hongweichen@microsoft.com> Signed-off-by: gyou2021 <ganmei.you@intel.com>
1 parent f4b0f58 commit 3ca3e2f

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

op_builder/builder.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ def installed_cuda_version(name=""):
6161

6262
def get_default_compute_capabilities():
6363
compute_caps = DEFAULT_COMPUTE_CAPABILITIES
64+
# Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported
6465
import torch.utils.cpp_extension
65-
if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
66-
if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
67-
# Special treatment of CUDA 11.0 because compute_86 is not supported.
68-
compute_caps += ";8.0"
69-
else:
66+
if torch.utils.cpp_extension.CUDA_HOME is not None:
67+
if installed_cuda_version()[0] == 11:
68+
if installed_cuda_version()[1] >= 0:
69+
compute_caps += ";8.0"
70+
if installed_cuda_version()[1] >= 1:
71+
compute_caps += ";8.6"
72+
if installed_cuda_version()[1] >= 8:
73+
compute_caps += ";9.0"
74+
elif installed_cuda_version()[0] == 12:
7075
compute_caps += ";8.0;8.6;9.0"
76+
if installed_cuda_version()[1] >= 8:
77+
compute_caps += ";10.0;12.0"
7178
return compute_caps
7279

7380

0 commit comments

Comments
 (0)