diff --git a/CMakeLists.txt b/CMakeLists.txt index 8df349ce14fd..652b9bcd5157 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set CUDA include flags for CXX compiler. +# +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include") + if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl") + endif() +endif() + # # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h new file mode 100644 index 000000000000..470a63a22cab --- /dev/null +++ b/csrc/cub_helpers.h @@ -0,0 +1,17 @@ +#pragma once + +#ifndef USE_ROCM + #include + #if CUB_VERSION >= 200800 + #include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; + #else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; + #endif // CUB_VERSION +#else + #include +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; +#endif // USE_ROCM diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 05be023de0f2..93c73d58390e 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,15 +1,10 @@ #include "type_convert.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. @@ -30,7 +25,7 @@ __global__ void rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -85,7 +80,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -126,7 +121,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fd5849d9626..be134089bd6d 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -8,16 +8,11 @@ #include "type_convert.cuh" #include "quantization/fp8/common.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. @@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index cd80bfda7dfd..53573ada86ba 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -20,17 +20,7 @@ #include #include #include "../cuda_compat.h" - -#ifndef USE_ROCM - #include - #include - #include - using AddOp = cuda::std::plus; -#else - #include - #include - using AddOp = cub::Sum; -#endif +#include "../cub_helpers.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__ threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp()); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); if (threadIdx.x == 0) { diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index d8369108d0bd..bcfde9fbcbbe 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -7,17 +7,10 @@ #include +#include "../../cub_helpers.h" #include "../../dispatch_utils.h" #include "../vectorization_utils.cuh" -#ifndef USE_ROCM - #include - #include -#else - #include - #include -#endif - static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM static constexpr auto i8_min = @@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; - float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); + float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x); __shared__ float absmax; if (tid == 0) { absmax = block_max; diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 5fe5dd04bd89..45d6d5082ce4 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,15 +1,10 @@ #include "common.cuh" #include "dispatch_utils.h" +#include "../../cub_helpers.h" #include "../vectorization_utils.cuh" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { template @@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; const float block_max = - BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x); __shared__ float token_scale; if (tid == 0) { diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 3f188872d80d..2d2fd771205c 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -8,11 +8,7 @@ #include "quantization/utils.cuh" #include "quant_conversions.cuh" -#ifndef USE_ROCM - #include -#else - #include -#endif +#include "../../cub_helpers.h" namespace vllm { @@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { @@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) {