Support H100 for other CUDA extensions
tridao committed Mar 15, 2023
1 parent 1b18f1b commit dc08ea1
Showing 7 changed files with 122 additions and 88 deletions.
53 changes: 31 additions & 22 deletions csrc/ft_attention/setup.py
@@ -1,12 +1,15 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
-
 import sys
 import warnings
 import os
+from packaging.version import parse, Version
+
+from setuptools import setup, find_packages
+import subprocess
+
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +72,18 @@ def append_nvcc_threads(nvcc_extra_args):
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +101,16 @@ def append_nvcc_threads(nvcc_extra_args):
 raise_if_cuda_home_none("--ft_attention")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_70,code=sm_70")
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("ft_attention is only supported on CUDA 11 and above")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
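A note on the version handling above: the old helper compared major and minor components as separate strings, so a check like int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2 silently rejects CUDA 12.0, and raw string comparison misorders releases such as 11.10 vs 11.2. A minimal sketch of why the parsed Version comparison is safer (illustrative only, not part of this commit):

# Illustrative sketch; not part of the commit.
from packaging.version import Version, parse

# Lexicographic string comparison misorders multi-digit minor versions:
assert "11.10" < "11.2"                     # wrong order as strings
assert parse("11.10") > Version("11.2")     # correct order when parsed

# The old major/minor check wrongly fails for CUDA 12.0:
major, minor = "12.0".split(".")
assert not (int(major) >= 11 and int(minor) >= 2)  # old check: no --threads
assert Version("12.0") >= Version("11.2")          # new check: --threads on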
11 changes: 5 additions & 6 deletions csrc/fused_dense_lib/setup.py
@@ -1,5 +1,6 @@
 import os
 import subprocess
+from packaging.version import parse, Version

 import torch
 from setuptools import setup
@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

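For context, append_nvcc_threads feeds the nvcc half of a CUDAExtension's extra_compile_args; the extension definition itself is collapsed in the diff above, so the following is only a sketch of the usual consumption pattern (the module name and source list here are assumptions, not read from the diff):

# Hedged sketch of typical usage; name and sources are assumptions.
from torch.utils.cpp_extension import CUDAExtension

ext_modules = [
    CUDAExtension(
        name="fused_dense_lib",
        sources=["fused_dense.cpp", "fused_dense_cuda.cu"],
        extra_compile_args={
            "cxx": ["-O3"],
            # On nvcc >= 11.2 this appends "--threads 4" so device code
            # for several architectures is compiled in parallel.
            "nvcc": append_nvcc_threads(["-O3"]),
        },
    )
]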
45 changes: 26 additions & 19 deletions csrc/layer_norm/setup.py
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess

-import sys
-import warnings
-import os
-
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))

@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +70,18 @@ def append_nvcc_threads(nvcc_extra_args):
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +99,16 @@ def append_nvcc_threads(nvcc_extra_args):
 raise_if_cuda_home_none("--fast_layer_norm")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
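The TORCH_CUDA_ARCH_LIST fallback above encodes toolkit support windows: sm_86 (Ampere GA10x) requires CUDA >= 11.1 and sm_90 (Hopper/H100) requires CUDA >= 11.8. A standalone sketch of the same selection logic, factored into a function for illustration (the function name is invented here):

# Illustrative refactor of the fallback above; default_arch_list is invented.
from packaging.version import Version

def default_arch_list(cuda: Version) -> str:
    if cuda >= Version("11.8"):    # Hopper (sm_90) support arrives in 11.8
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
    if cuda >= Version("11.1"):    # sm_86 supported from 11.1
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
    if cuda == Version("11.0"):
        return "6.0;6.1;6.2;7.0;7.5;8.0"
    return "6.0;6.1;6.2;7.0;7.5"   # older toolkits: Pascal through Turing

assert default_arch_list(Version("11.8")).endswith("9.0")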
45 changes: 26 additions & 19 deletions csrc/rotary/setup.py
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess

-import sys
-import warnings
-import os
-
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))

@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +70,18 @@ def append_nvcc_threads(nvcc_extra_args):
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+

 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -91,10 +92,16 @@ def append_nvcc_threads(nvcc_extra_args):
 raise_if_cuda_home_none("rotary_emb")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
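Each -gencode pair built above targets one architecture: arch=compute_XX selects the virtual (PTX) architecture and code=sm_XX the real GPU to emit SASS for. A small sketch of the flag list the new branch produces on a CUDA >= 11.8 toolchain (illustrative, mirroring the appends above):

# Sketch: flags as built above, assuming nvcc >= 11.8 so sm_90 is included.
cc_flag = []
for sm in ("70", "80", "90"):
    cc_flag += ["-gencode", f"arch=compute_{sm},code=sm_{sm}"]

print(" ".join(cc_flag))
# -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80
# -gencode arch=compute_90,code=sm_90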
(Diffs for the 3 remaining changed files are not shown.)
