Commit bfe9380

Apply fixes for CUDA 13 (#24599)
Signed-off-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
Parent: 9fccd04 · Commit: bfe9380

8 files changed: +47 / -56 lines


CMakeLists.txt

Lines changed: 10 additions & 0 deletions
@@ -175,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()

+#
+# Set CUDA include flags for CXX compiler.
+#
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include")
+  if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl")
+  endif()
+endif()
+
 #
 # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
 # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
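
For context: with CUDA 13 the bundled CCCL headers (CUB, Thrust, libcu++) live under ${CUDA_TOOLKIT_ROOT_DIR}/include/cccl, which is why the extra include directory is appended only for CUDA >= 13.0. A rough illustration of what this enables (hypothetical file, not part of this commit): a translation unit compiled by the host C++ compiler rather than nvcc can now resolve CCCL headers such as <cuda/std/functional>.

// Hypothetical host-only C++ translation unit (not part of this commit),
// compiled with plain g++/clang++. Under CUDA 13 the #include below resolves
// only because -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl was appended to
// CMAKE_CXX_FLAGS above.
#include <cuda/std/functional>

int main() {
  cuda::std::plus<> add;  // the functor CubAddOp aliases when CUB_VERSION >= 200800
  return add(2, 2) == 4 ? 0 : 1;
}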

csrc/cub_helpers.h

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#pragma once
+
+#ifndef USE_ROCM
+#include <cub/cub.cuh>
+#if CUB_VERSION >= 200800
+#include <cuda/std/functional>
+using CubAddOp = cuda::std::plus<>;
+using CubMaxOp = cuda::maximum<>;
+#else  // if CUB_VERSION < 200800
+using CubAddOp = cub::Sum;
+using CubMaxOp = cub::Max;
+#endif  // CUB_VERSION
+#else
+#include <hipcub/hipcub.hpp>
+using CubAddOp = cub::Sum;
+using CubMaxOp = cub::Max;
+#endif  // USE_ROCM
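
The aliases above are drop-in reduction functors for cub::BlockReduce: on CUB 2.8 or newer (CUDA 12.8+, including CUDA 13) they map to cuda::std::plus<> and cuda::maximum<>, otherwise to the legacy cub::Sum and cub::Max, and on ROCm to the hipCUB-provided cub::Sum/cub::Max. A minimal usage sketch (hypothetical kernel, not part of this commit) of the pattern the files below switch to:

// Hypothetical CUDA kernel (not from this commit) showing how CubAddOp plugs
// into cub::BlockReduce exactly where cub::Sum{} used to be passed.
#include "cub_helpers.h"

__global__ void block_sum_kernel(const float* __restrict__ in,
                                 float* __restrict__ out) {
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;

  float val = in[blockIdx.x * blockDim.x + threadIdx.x];
  // CubAddOp{} is a stateless binary "+" functor, same call shape as cub::Sum{}.
  float block_sum = BlockReduce(tmp).Reduce(val, CubAddOp{}, blockDim.x);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = block_sum;
  }
}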

csrc/layernorm_kernels.cu

Lines changed: 4 additions & 9 deletions
@@ -1,15 +1,10 @@
 #include "type_convert.cuh"
 #include "dispatch_utils.h"
+#include "cub_helpers.h"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif
-
 namespace vllm {

 // TODO(woosuk): Further optimize this kernel.
@@ -30,7 +25,7 @@ __global__ void rms_norm_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -85,7 +80,7 @@ fused_add_rms_norm_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -126,7 +121,7 @@ fused_add_rms_norm_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);

csrc/layernorm_quant_kernels.cu

Lines changed: 4 additions & 9 deletions
@@ -8,16 +8,11 @@
 #include "type_convert.cuh"
 #include "quantization/fp8/common.cuh"
 #include "dispatch_utils.h"
+#include "cub_helpers.h"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif
-
 namespace vllm {

 // TODO(woosuk): Further optimize this kernel.
@@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
@@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);

csrc/moe/topk_softmax_kernels.cu

Lines changed: 3 additions & 13 deletions
@@ -20,17 +20,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
-
-#ifndef USE_ROCM
-#include <cub/util_type.cuh>
-#include <cub/cub.cuh>
-#include <cuda/std/functional>
-using AddOp = cuda::std::plus<float>;
-#else
-#include <hipcub/util_type.hpp>
-#include <hipcub/hipcub.hpp>
-using AddOp = cub::Sum;
-#endif
+#include "../cub_helpers.h"

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__
         threadData = max(static_cast<float>(input[idx]), threadData);
     }

-    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
     if (threadIdx.x == 0)
     {
         float_max = maxElem;
@@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__
         threadData += exp((static_cast<float>(input[idx]) - float_max));
     }

-    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
+    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());

     if (threadIdx.x == 0)
     {

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 2 additions & 9 deletions
@@ -7,17 +7,10 @@

 #include <cmath>

+#include "../../cub_helpers.h"
 #include "../../dispatch_utils.h"
 #include "../vectorization_utils.cuh"

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#include <cub/util_type.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#include <hipcub/util_type.hpp>
-#endif
-
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
   static constexpr auto i8_min =
@@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   });
   using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
-  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
+  float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x);
   __shared__ float absmax;
   if (tid == 0) {
     absmax = block_max;

csrc/quantization/fp8/common.cu

Lines changed: 2 additions & 7 deletions
@@ -1,15 +1,10 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
+#include "../../cub_helpers.h"
 #include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/Exceptions.h>

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif
-
 namespace vllm {

 template <typename scalar_t, typename fp8_type>
@@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
   using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage tmp;
   const float block_max =
-      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+      BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);

   __shared__ float token_scale;
   if (tid == 0) {

csrc/quantization/fused_kernels/layernorm_utils.cuh

Lines changed: 5 additions & 9 deletions
@@ -8,11 +8,7 @@
 #include "quantization/utils.cuh"
 #include "quant_conversions.cuh"

-#ifndef USE_ROCM
-#include <cub/cub.cuh>
-#else
-#include <hipcub/hipcub.hpp>
-#endif
+#include "../../cub_helpers.h"

 namespace vllm {

@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+  ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

   __shared__ float s_rms;
   if (threadIdx.x == 0) {
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
   __shared__ typename BlockReduce::TempStorage reduceStore;
   block_absmax_val_maybe =
       BlockReduce(reduceStore)
-          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+          .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

   __shared__ float s_token_scale;
   if (threadIdx.x == 0) {
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,

   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
-  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
+  ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);

   __shared__ float s_rms;
   if (threadIdx.x == 0) {
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
   __shared__ typename BlockReduce::TempStorage reduceStore;
   block_absmax_val_maybe =
       BlockReduce(reduceStore)
-          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
+          .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

   __shared__ float s_token_scale;
   if (threadIdx.x == 0) {
