88#include " quantization/utils.cuh"
99#include " quant_conversions.cuh"
1010
11- #ifndef USE_ROCM
12- #include < cub/cub.cuh>
13- #else
14- #include < hipcub/hipcub.hpp>
15- #endif
11+ #include " ../../cub_helpers.h"
1612
1713namespace vllm {
1814
@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
3632
3733 using BlockReduce = cub::BlockReduce<float , 1024 >;
3834 __shared__ typename BlockReduce::TempStorage reduceStore;
39- ss = BlockReduce (reduceStore).Reduce (ss, cub::Sum {}, blockDim .x );
35+ ss = BlockReduce (reduceStore).Reduce (ss, CubAddOp {}, blockDim .x );
4036
4137 __shared__ float s_rms;
4238 if (threadIdx .x == 0 ) {
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
7369 __shared__ typename BlockReduce::TempStorage reduceStore;
7470 block_absmax_val_maybe =
7571 BlockReduce (reduceStore)
76- .Reduce (block_absmax_val_maybe, cub::Max {}, blockDim .x );
72+ .Reduce (block_absmax_val_maybe, CubMaxOp {}, blockDim .x );
7773
7874 __shared__ float s_token_scale;
7975 if (threadIdx .x == 0 ) {
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
169165
170166 using BlockReduce = cub::BlockReduce<float , 1024 >;
171167 __shared__ typename BlockReduce::TempStorage reduceStore;
172- ss = BlockReduce (reduceStore).Reduce (ss, cub::Sum {}, blockDim .x );
168+ ss = BlockReduce (reduceStore).Reduce (ss, CubAddOp {}, blockDim .x );
173169
174170 __shared__ float s_rms;
175171 if (threadIdx .x == 0 ) {
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
240236 __shared__ typename BlockReduce::TempStorage reduceStore;
241237 block_absmax_val_maybe =
242238 BlockReduce (reduceStore)
243- .Reduce (block_absmax_val_maybe, cub::Max {}, blockDim .x );
239+ .Reduce (block_absmax_val_maybe, CubMaxOp {}, blockDim .x );
244240
245241 __shared__ float s_token_scale;
246242 if (threadIdx .x == 0 ) {
0 commit comments