@@ -15,15 +15,16 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }

@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ input,          // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

   const int vec_hidden_size = hidden_size / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
+    _f16Vec<scalar_t, width> temp = input_v[strided_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
     _f16Vec<scalar_t, width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
-    input_v[id] = temp;
+    input_v[strided_id] = temp;
   }
 }

@@ -103,15 +108,16 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ input,          // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] =
+    input[blockIdx.x * input_stride + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
@@ -141,38 +147,42 @@ void rms_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());

   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
+  int64_t input_stride = input.stride(-2);

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
 }

-#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
-        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
-                                         residual.data_ptr<scalar_t>(),        \
-                                         weight.data_ptr<scalar_t>(), epsilon, \
-                                         num_tokens, hidden_size);             \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
+  VLLM_DISPATCH_FLOATING_TYPES(                                              \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \
+        vllm::fused_add_rms_norm_kernel<scalar_t, width>                     \
+            <<<grid, block, 0, stream>>>(                                    \
+                input.data_ptr<scalar_t>(), input_stride,                    \
+                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),  \
+                epsilon, num_tokens, hidden_size);                           \
   });

 void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

   dim3 grid(num_tokens);
@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned =
-      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  constexpr int vector_width = 8;
+  constexpr int req_alignment_bytes =
+      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
+                         // falls back to non-vectorized version anyway)
+  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
+                          res_ptr % req_alignment_bytes == 0 &&
+                          wt_ptr % req_alignment_bytes == 0;
+  bool offsets_are_multiple_of_vector_width =
+      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
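Note (not part of the diff): after this change, the host wrappers only require that the last dimension of `input` be contiguous (`stride(-1) == 1`); `out`, `residual`, and `weight` must still be fully contiguous. This lets a row view into a wider padded buffer be normalized in place without first materializing a contiguous copy. A minimal libtorch C++ sketch of such an input follows; the buffer shape and variable names are hypothetical, not taken from the commit.

#include <torch/torch.h>

int main() {
  const int64_t num_tokens = 4;
  const int64_t hidden_size = 1024;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

  // A padded [num_tokens, 2 * hidden_size] buffer; keep only the first
  // hidden_size columns of each row. The view shares storage with `padded`.
  auto padded = torch::randn({num_tokens, 2 * hidden_size}, opts);
  auto input = padded.narrow(/*dim=*/1, /*start=*/0, /*length=*/hidden_size);

  // input.is_contiguous() == false
  // input.stride(-1) == 1                -> passes TORCH_CHECK(input.stride(-1) == 1)
  // input.stride(-2) == 2 * hidden_size  -> forwarded to the kernels as input_stride
  //
  // Because hidden_size % 8 == 0, input.stride(-2) % 8 == 0, and narrowing at
  // offset 0 keeps the 16-byte-aligned base pointer of the fp16 allocation,
  // fused_add_rms_norm would still select the vectorized width-8 kernel.
  return 0;
}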