From 14b8fdde78d39a0d6bb938c69d47ad3ca2fd94d4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 27 May 2023 06:24:31 +0000
Subject: [PATCH 1/4] Fix setup script & bfloat16 guard

---
 csrc/attention/attention_dtypes.h   |  3 --
 csrc/attention/attention_kernels.cu |  2 --
 csrc/attention/dtype_bfloat16.cuh   | 44 +++++++++++++++++++++++++++
 setup.py                            | 47 +++++++++++++++++++++++------
 4 files changed, 81 insertions(+), 15 deletions(-)

diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h
index b04ea9a1145e..88b4eddec7fc 100644
--- a/csrc/attention/attention_dtypes.h
+++ b/csrc/attention/attention_dtypes.h
@@ -3,7 +3,4 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
-
-#ifdef ENABLE_BF16
 #include "dtype_bfloat16.cuh"
-#endif // ENABLE_BF16
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 4a17c9afd479..0854f343ffc0 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
   // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
-#ifdef ENABLE_BF16
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh
index e6b1c470c40d..150494122f33 100644
--- a/csrc/attention/dtype_bfloat16.cuh
+++ b/csrc/attention/dtype_bfloat16.cuh
@@ -78,20 +78,36 @@ struct FloatVec {
 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat1622float2(val);
+#endif
 }
 
 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat162bfloat162(val);
+#endif
 }
 
 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return a + b;
+#endif
 }
 
 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hadd2(a, b);
+#endif
 }
 
 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul(a, b);
+#endif
 }
 
 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul2(a, b);
+#endif
 }
 
 template<>
@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(a, b, c);
+#endif
 }
 
 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(bf162bf162(a), b, c);
+#endif
 }
 
 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }
 
 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst = __float22bfloat162_rn(src);
+#endif
 }
 
 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
+#endif
 }
 
 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
+#endif
 }
 
 } // namespace cacheflow
diff --git a/setup.py b/setup.py
index 48538ffe9834..95b8f326faf5 100644
--- a/setup.py
+++ b/setup.py
@@ -1,28 +1,55 @@
+import subprocess
 from typing import List
 
+from packaging.version import parse, Version
 import setuptools
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 from torch.utils.cpp_extension import CUDA_HOME
 
-
-# Build custom operators.
-CXX_FLAGS = ["-g"]
+# Compiler flags.
+CXX_FLAGS = ["-g", "-O2"]
 # TODO(woosuk): Should we use -O3?
 NVCC_FLAGS = ["-O2"]
+
 if not torch.cuda.is_available():
     raise RuntimeError(
         f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
         "CUDA must be available in order to build the package.")
 
-# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
-# different compute capabilities.
-compute_capability = torch.cuda.get_device_capability()
-major, minor = compute_capability
-# Enable bfloat16 support if the compute capability is >= 8.0.
-if major >= 8:
-    NVCC_FLAGS.append("-DENABLE_BF16")
+
+def get_nvcc_cuda_version(cuda_dir: str) -> Version:
+    """Get the CUDA version from nvcc.
+
+    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
+    """
+    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return bare_metal_version
+
+
+# Check CUDA version.
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if nvcc_cuda_version < Version("11.0"):
+    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
+
+# Select architectures to compile for based on the NVCC CUDA version.
+# NOTE(woosuk): This will increase the build time as we compile for multiple
+# architectures, regardless of whether they are used or not.
+# CUDA 11.0 supports compute capability up to 8.0.
+NVCC_FLAGS += ["-gencode", "arch=compute_70,code=sm_70"]
+NVCC_FLAGS += ["-gencode", "arch=compute_75,code=sm_75"]
+NVCC_FLAGS += ["-gencode", "arch=compute_80,code=sm_80"]
+# Compute capability 8.6 is supported since CUDA 11.1.
+if nvcc_cuda_version >= Version("11.1"): + NVCC_FLAGS += ["-gencode", "arch=compute_86,code=sm_86"] +# Compute capability 9.0 is supported since CUDA 11.8. +if nvcc_cuda_version >= Version("11.8"): + NVCC_FLAGS += ["-gencode", "arch=compute_90,code=sm_90"] ext_modules = [] From f139611b72a5e1a5cc68abeec6175011dedee2a5 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sat, 27 May 2023 07:39:46 +0000 Subject: [PATCH 2/4] Optimize build speed --- setup.py | 45 +++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index 95b8f326faf5..4ca2fef87a3a 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import subprocess -from typing import List +from typing import List, Set from packaging.version import parse, Version import setuptools @@ -28,28 +28,33 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: universal_newlines=True) output = nvcc_output.split() release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - return bare_metal_version - - -# Check CUDA version. + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +# Collect the compute capabilities of all available GPUs. +device_count = torch.cuda.device_count() +compute_capabilities: Set[int] = set() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + compute_capabilities.add(major * 10 + minor) +# If no GPUs are available, add all supported compute capabilities. +if not compute_capabilities: + compute_capabilities = {70, 75, 80, 86, 90} +# Add target compute capabilities to NVCC flags. +for capability in compute_capabilities: + NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"] + +# Validate the NVCC CUDA version. nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if nvcc_cuda_version < Version("11.0"): raise RuntimeError("CUDA 11.0 or higher is required to build the package.") - -# Select architectures to compile for based on the NVCC CUDA version. -# NOTE(woosuk): This will increase the build time as we compile for multiple -# architectures, regardless of whether they are used or not. -# CUDA 11.0 supports compute capability up to 8.0. -NVCC_FLAGS += ["-gencode", "arch=compute_70,code=sm_70"] -NVCC_FLAGS += ["-gencode", "arch=compute_75,code=sm_75"] -NVCC_FLAGS += ["-gencode", "arch=compute_80,code=sm_80"] -# Compute capability 8.6 is supported since CUDA 11.1. -if nvcc_cuda_version >= Version("11.1"): - NVCC_FLAGS += ["-gencode", "arch=compute_86,code=sm_86"] -# Compute capability 9.0 is supported since CUDA 11.8. 
-if nvcc_cuda_version >= Version("11.8"): - NVCC_FLAGS += ["-gencode", "arch=compute_90,code=sm_90"] +if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): + raise RuntimeError( + "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") +if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): + raise RuntimeError( + "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") ext_modules = [] From e320f1a2f45ea382f8f5c691f3f23ab29b2b26e0 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sat, 27 May 2023 07:54:13 +0000 Subject: [PATCH 3/4] Minor --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 4ca2fef87a3a..d4e6f0911948 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: compute_capabilities: Set[int] = set() for i in range(device_count): major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability less than 7.0 are not supported.") compute_capabilities.add(major * 10 + minor) # If no GPUs are available, add all supported compute capabilities. if not compute_capabilities: From 0b4ad1736632fded1a7a3591b0f7cdd5ca79a50a Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Sat, 27 May 2023 07:58:36 +0000 Subject: [PATCH 4/4] Minor fix: --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d4e6f0911948..8526782168c6 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: raise RuntimeError( "GPUs with compute capability less than 7.0 are not supported.") compute_capabilities.add(major * 10 + minor) -# If no GPUs are available, add all supported compute capabilities. +# If no GPU is available, add all supported compute capabilities. if not compute_capabilities: compute_capabilities = {70, 75, 80, 86, 90} # Add target compute capabilities to NVCC flags.
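Note on the first patch: with the compile-time ENABLE_BF16 gate removed, the bfloat16 header is always built, and the new __CUDA_ARCH__ < 800 guards make pre-Ampere device code hit assert(false) instead of silently lacking the symbols. Whether the bf16 kernels are safe to use is therefore a runtime question. The sketch below shows one way a caller could express that check from Python; it is an illustration only — the helper name bf16_kernels_supported does not exist in the patch, and the major >= 8 test simply mirrors the check the old setup.py used before appending -DENABLE_BF16.

import torch


def bf16_kernels_supported(device_index: int = 0) -> bool:
    """Return True if the guarded bf16 device intrinsics can run on this GPU.

    Mirrors the `major >= 8` condition the old setup.py used to decide
    whether to define ENABLE_BF16 at build time.
    """
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_index)
    return major >= 8


# Example: fall back to float16 on pre-Ampere GPUs instead of tripping the
# assert(false) inside the guarded bf16 device functions.
dtype = torch.bfloat16 if bf16_kernels_supported() else torch.float16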
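For a consolidated view of where setup.py ends up after all four patches: target architectures are taken from the GPUs that are actually visible (falling back to every supported architecture when none are), capabilities below 7.0 are rejected, and the chosen targets are cross-checked against the nvcc version. The block below condenses that flow into a standalone sketch rather than a verbatim copy of the file; the names collect_compute_capabilities and build_gencode_flags are illustrative and are not defined by the patches, and it assumes CUDA_HOME points at a working toolkit, as the patched setup.py does.

import subprocess
from typing import List, Set

import torch
from packaging.version import Version, parse
from torch.utils.cpp_extension import CUDA_HOME


def get_nvcc_cuda_version(cuda_dir: str) -> Version:
    """Parse the CUDA release number out of `nvcc -V`, as the patch does."""
    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                          universal_newlines=True)
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    return parse(output[release_idx].split(",")[0])


def collect_compute_capabilities() -> Set[int]:
    """Encode each visible GPU's capability as major * 10 + minor (8.0 -> 80)."""
    capabilities: Set[int] = set()
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 7:
            raise RuntimeError(
                "GPUs with compute capability less than 7.0 are not supported.")
        capabilities.add(major * 10 + minor)
    # No GPU visible (e.g. a CPU-only build box): compile for every supported arch.
    return capabilities or {70, 75, 80, 86, 90}


def build_gencode_flags(capabilities: Set[int], nvcc_version: Version) -> List[str]:
    """Turn capabilities into -gencode flags, enforcing the toolkit requirements."""
    if nvcc_version < Version("11.0"):
        raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
    if 86 in capabilities and nvcc_version < Version("11.1"):
        raise RuntimeError(
            "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
    if 90 in capabilities and nvcc_version < Version("11.8"):
        raise RuntimeError(
            "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
    flags: List[str] = []
    for capability in sorted(capabilities):
        flags += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
    return flags


if __name__ == "__main__":
    nvcc_version = get_nvcc_cuda_version(CUDA_HOME)
    print(build_gencode_flags(collect_compute_capabilities(), nvcc_version))

For example, on a machine with a single A100 (compute capability 8.0) and a CUDA 11.8 toolkit, the sketch prints ['-gencode', 'arch=compute_80,code=sm_80'], matching what the patched setup.py would append to NVCC_FLAGS.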