 #include "common.cuh"
 #include "dispatch_utils.h"
-
+#include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
 namespace vllm {
 
 template <typename scalar_t, typename fp8_type>
-__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
-                                        const scalar_t* __restrict__ input,
-                                        const float* __restrict__ scale,
-                                        int64_t num_elems) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // Invert the scale so that we can use multiplications to avoid expensive
-  // division.
-  const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec<scalar_t, true>(
-      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+__global__ void scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;  // one token per block
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float inv_scale = 1.0f / (*scale);
+
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    inv_scale);
+      });
 }
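
The strided kernels in this file lean on `vectorize_with_alignment<16>` from `../vectorization_utils.cuh`, which is not part of this change. As a rough sketch of the contract the kernels assume (the real helper is expected to additionally take a 16-byte vectorized load/store path when the row pointers are aligned; `vectorize_with_alignment_ref` and its exact parameter order are illustrative only):

// Illustrative, scalar-only stand-in for the assumed helper. It applies
// fn(dst, src) to every element of one row, with each thread striding by
// `stride` (blockDim.x in the kernels above). The real helper is assumed to
// pick a 16-byte vectorized path when `in`/`out` are suitably aligned.
template <int kAlignBytes, typename InT, typename OutT, typename Fn>
__device__ void vectorize_with_alignment_ref(const InT* in, OutT* out, int n,
                                             int tid, int stride, Fn&& fn) {
  for (int i = tid; i < n; i += stride) {
    fn(out[i], in[i]);
  }
}
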
 
 template <typename scalar_t, typename fp8_type>
-__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    fp8_type* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+__global__ void segmented_max_reduction_strided(
+    float* __restrict__ scale, const scalar_t* __restrict__ input,
+    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
+  __shared__ float cache[256];
+  const int tid = threadIdx.x;
+  int64_t token_idx = blockIdx.x;
+
+  // one block per token. Guard in case gridDim.x > num_tokens.
+  if (token_idx >= num_tokens) {
+    return;
+  }
 
-  // Use int64 to avoid overflowing an int32 when calculating this offset
-  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
-  scalar_t const* __restrict__ token_input = &input[offset];
-  fp8_type* __restrict__ token_output = &out[offset];
-
-  // For vectorization, token_input and token_output pointers need to be
-  // aligned at 32-byte and 16-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 16 == 0;
-
-  float absmax_val = 0.0f;
-  if (can_vectorize) {
-    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      float const x = static_cast<float>(token_input[i]);
-      absmax_val = fmaxf(absmax_val, fabsf(x));
+  const scalar_t* row_ptr = input + token_idx * in_row_stride;
+
+  // each thread scans elements of the row in a strided fashion.
+  float thread_max = 0.0f;
+  for (int e = tid; e < hidden_size; e += blockDim.x) {
+    float v = fabsf(static_cast<float>(row_ptr[e]));
+    thread_max = fmaxf(thread_max, v);
+  }
+
+  cache[tid] = thread_max;
+  __syncthreads();
+
+  // parallel reduction to find row max.
+  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
     }
+    __syncthreads();
   }
 
+  // thread 0 updates global scale (per-tensor) atomically.
+  if (tid == 0) {
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
+  }
+}
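
`atomicMaxFloat` and `quant_type_max_v` come from the pre-existing common headers and are not shown in this diff. Because both operands here are non-negative (the row maximum divided by the fp8 range, and a scale buffer the host zero-fills before launch), a float atomic max can be built on the integer `atomicMax` by reinterpreting the bits. This is only a sketch of one common implementation, not necessarily the one the header uses:

// Sketch of a float atomic max that is valid for non-negative operands:
// for non-negative IEEE-754 floats, the bit pattern compares the same way
// as a signed integer, so integer atomicMax yields the float maximum.
__device__ __forceinline__ float atomic_max_nonneg_float(float* addr,
                                                         float value) {
  int old = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
  return __int_as_float(old);
}
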
+
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel_strided_dynamic(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float reciprocal_scale = 1.0f / (*scale);
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    reciprocal_scale);
+      });
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, float* __restrict__ scale,
+    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
+    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  // Use int64 to avoid overflowing an int32 when calculating this offset
+  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
+  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
+  const scalar_t* token_in = input + in_offset;
+  fp8_type* token_out = out + out_offset;
+
+  // 1) per-token absmax
+  float absmax_val = 0.f;
+  vectorize_read_with_alignment<16>(
+      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
+        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
+      });
+
   using BlockReduce = cub::BlockReduce<float, 256>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
+  __shared__ typename BlockReduce::TempStorage tmp;
+  const float block_max =
+      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+
   __shared__ float token_scale;
   if (tid == 0) {
-    if (scale_ub) {
-      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
-    } else {
-      token_scale = block_absmax_val_maybe;
-    }
-    // token scale computation
+    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
     token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                         min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  // Note that we don't use inverted scales so we can match FBGemm impl.
-  if (can_vectorize) {
-    scaled_fp8_conversion_vec<scalar_t, false>(
-        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
-          static_cast<float>(token_input[i]), token_scale);
-    }
-  }
+  // 2) quantize
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
+                                                     token_scale);
+      });
 }
 
 }  // namespace vllm
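
For the per-token path, the scale written to `scale[token_idx]` is the row's absolute maximum (optionally clamped by `scale_ub`) divided by the fp8 range, floored at `min_scaling_factor`; elements are then divided by that scale rather than multiplied by its reciprocal. A small worked example on the host, assuming the float8_e4m3fn range of 448 and an assumed floor of 1/(448*512); the exact floor lives in the common header:

#include <algorithm>
#include <cstdio>

int main() {
  const float fp8_max = 448.0f;                       // float8_e4m3fn max magnitude
  const float min_scale = 1.0f / (448.0f * 512.0f);   // assumed floor
  const float row_absmax = 3.5f;                      // per-token |x| maximum

  // token_scale = max(row_absmax / fp8_max, min_scale); no scale_ub here.
  float token_scale = std::max(row_absmax / fp8_max, min_scale);

  // Quantization divides by the scale (no reciprocal), matching the kernel.
  float q = 1.75f / token_scale;  // 1.75 / 0.0078125 = 224, within fp8 range
  std::printf("token_scale = %g, quantized value = %g\n", token_scale, q);
  return 0;
}
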
@@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
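
With the `is_contiguous()` checks replaced by last-dimension checks, the launcher now accepts inputs whose rows are not densely packed, e.g. a slice along the hidden dimension of a wider buffer; the row stride is simply forwarded to the kernel. A hypothetical host-side call (the forward declaration mirrors the definition above, tensor shapes and names are illustrative, and a PyTorch build with fp8 dtypes is assumed):

#include <torch/torch.h>

// Mirrors the launcher defined above (normally declared in vLLM's ops header).
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);

void quantize_sliced_rows(int64_t num_tokens, int64_t hidden) {
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  // `input` views the first `hidden` columns of a 2*hidden-wide buffer, so
  // input.stride(-1) == 1 but input.stride(-2) == 2 * hidden.
  auto buffer = torch::randn({num_tokens, 2 * hidden}, opts);
  auto input = buffer.narrow(/*dim=*/1, /*start=*/0, /*length=*/hidden);

  auto out = torch::empty({num_tokens, hidden},
                          opts.dtype(at::ScalarType::Float8_e4m3fn));
  auto scale = torch::ones({1}, opts.dtype(torch::kFloat));

  static_scaled_fp8_quant(out, input, scale);  // out_row_stride == hidden
}
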
@@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // scale tensor should be initialised to <=0 before reduction
+  AT_CUDA_CHECK(
+      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
+
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::segmented_max_reduction<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
-                                               input.data_ptr<scalar_t>(),
-                                               num_elems);
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                      hidden_size, in_row_stride,
+                      static_cast<int64_t>(num_tokens));
+
+              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
@@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out,          // [..., d]
     torch::Tensor const& input,  // [..., d]
     torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
 
-  int const hidden_size = input.size(-1);
-  int const num_tokens = input.numel() / hidden_size;
-  int const block_size = 256;
-  dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, block_size));
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, block_size));
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(),
             "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
-                      input.data_ptr<scalar_t>(),
-                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                           : nullptr,
-                      hidden_size);
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
+                  scalar_t, fp8_t><<<grid, block, 0, stream>>>(
+                  out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  hidden_size, in_row_stride, out_row_stride);
             });
       });
 }