diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index b04ea9a1145e2..88b4eddec7fc7 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -3,7 +3,4 @@ #include "attention_generic.cuh" #include "dtype_float16.cuh" #include "dtype_float32.cuh" - -#ifdef ENABLE_BF16 #include "dtype_bfloat16.cuh" -#endif // ENABLE_BF16 diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 4a17c9afd479c..0854f343ffc05 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -458,10 +458,8 @@ void single_query_cached_kv_attention( // TODO(woosuk): Support FP32. if (query.dtype() == at::ScalarType::Half) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); -#ifdef ENABLE_BF16 } else if (query.dtype() == at::ScalarType::BFloat16) { CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); -#endif } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index e6b1c470c40db..150494122f331 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -78,20 +78,36 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __bfloat1622float2(val); +#endif } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __bfloat162bfloat162(val); +#endif } // Vector addition. inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return a + b; +#endif } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __hadd2(a, b); +#endif } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { // Vector multiplication. template<> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __hmul(a, b); +#endif } template<> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __hmul2(a, b); +#endif } template<> @@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { // Vector fused multiply-add. inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __hfma2(a, b, c); +#endif } inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else return __hfma2(bf162bf162(a), b, c); +#endif } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { @@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) { } inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else dst = __float22bfloat162_rn(src); +#endif } inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else dst.x = __float22bfloat162_rn(src.x); dst.y = __float22bfloat162_rn(src.y); +#endif } inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else dst.x = __float22bfloat162_rn(src.x); dst.y = __float22bfloat162_rn(src.y); dst.z = __float22bfloat162_rn(src.z); dst.w = __float22bfloat162_rn(src.w); +#endif } } // namespace cacheflow diff --git a/setup.py b/setup.py index 48538ffe98348..8526782168c6b 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,63 @@ -from typing import List +import subprocess +from typing import List, Set +from packaging.version import parse, Version import setuptools import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import CUDA_HOME - -# Build custom operators. -CXX_FLAGS = ["-g"] +# Compiler flags. +CXX_FLAGS = ["-g", "-O2"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2"] + if not torch.cuda.is_available(): raise RuntimeError( f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. " "CUDA must be available in order to build the package.") -# FIXME(woosuk): Consider the case where the machine has multiple GPUs with -# different compute capabilities. -compute_capability = torch.cuda.get_device_capability() -major, minor = compute_capability -# Enable bfloat16 support if the compute capability is >= 8.0. -if major >= 8: - NVCC_FLAGS.append("-DENABLE_BF16") + +def get_nvcc_cuda_version(cuda_dir: str) -> Version: + """Get the CUDA version from nvcc. + + Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py + """ + nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +# Collect the compute capabilities of all available GPUs. +device_count = torch.cuda.device_count() +compute_capabilities: Set[int] = set() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability less than 7.0 are not supported.") + compute_capabilities.add(major * 10 + minor) +# If no GPU is available, add all supported compute capabilities. +if not compute_capabilities: + compute_capabilities = {70, 75, 80, 86, 90} +# Add target compute capabilities to NVCC flags. +for capability in compute_capabilities: + NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"] + +# Validate the NVCC CUDA version. +nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) +if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") +if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): + raise RuntimeError( + "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") +if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): + raise RuntimeError( + "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") ext_modules = []