@@ -44,15 +44,17 @@ inline __device__ void apply_rotary_embedding(
                                    // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
     const int num_kv_heads, const int rot_dim, const int token_idx,
-    const int64_t query_stride, const int64_t key_stride) {
+    const int64_t query_stride, const int64_t key_stride,
+    const int64_t head_stride) {
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
 
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int64_t token_head =
+        token_idx * query_stride + head_idx * head_stride;
     const int rot_offset = i % embed_dim;
     apply_token_rotary_embedding<scalar_t, IS_NEOX>(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int64_t token_head =
+        token_idx * key_stride + head_idx * head_stride;
     const int rot_offset = i % embed_dim;
     apply_token_rotary_embedding<scalar_t, IS_NEOX>(
         key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -84,15 +87,16 @@ __global__ void rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 template <typename scalar_t, bool IS_NEOX>
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-                                                        // or [num_tokens]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 }  // namespace vllm
@@ -179,6 +183,12 @@ void rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use the stride of the
+  // heads dim (-2); for flat [*, heads * head_size], head blocks are
+  // contiguous of size head_size.
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
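
Note on the head_stride selection in the hunk above: the ternary keys off whether query still carries a separate heads dim (positions_ndim + 2 dims) or has been flattened to [*, heads * head_size]. The standalone C++ sketch below (illustrative shapes and a made-up helper name, not part of the patch) reproduces that selection and the token_idx * query_stride + head_idx * head_stride offset the kernels now compute, including a hypothetical padded layout where stride(-2) != head_size and the old head_idx * head_size indexing would land in the wrong head.

#include <cstdint>
#include <cstdio>

// Mirrors the host-side selection added by this patch (illustrative only).
int64_t pick_head_stride(int query_ndim, int positions_ndim,
                         int64_t heads_dim_stride, int64_t head_size) {
  return (query_ndim == positions_ndim + 2) ? heads_dim_stride : head_size;
}

int main() {
  const int64_t num_heads = 4, head_size = 8;
  const int positions_ndim = 1;  // positions is [num_tokens]
  const int token_idx = 2, head_idx = 3;

  // Contiguous [num_tokens, num_heads, head_size]: stride(-2) == head_size,
  // so the new indexing matches the old head_idx * head_size math.
  int64_t contiguous = pick_head_stride(3, positions_ndim, head_size, head_size);

  // Hypothetical padded view: each head occupies 16 elements, so
  // stride(-2) != head_size. This is the case head_stride exists for.
  int64_t padded = pick_head_stride(3, positions_ndim, /*stride(-2)=*/16, head_size);

  // Flat [num_tokens, num_heads * head_size]: no separate heads dim, head
  // blocks are contiguous, fall back to head_size (old behaviour).
  int64_t flat = pick_head_stride(2, positions_ndim, /*unused*/ 0, head_size);

  std::printf("head_stride: contiguous=%lld padded=%lld flat=%lld\n",
              (long long)contiguous, (long long)padded, (long long)flat);

  // Element offset of (token 2, head 3) in the padded layout, as the kernel
  // now computes it: token_idx * query_stride + head_idx * head_stride.
  const int64_t query_stride = num_heads * 16;  // token stride, in elements
  std::printf("offset = %lld\n",
              (long long)(token_idx * query_stride + head_idx * padded));
  return 0;
}
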
@@ -190,14 +200,14 @@ void rotary_embedding(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
           key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
           cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
-          num_heads, num_kv_heads, head_size);
+          head_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use the stride of the
+  // heads dim (-2); for flat [*, heads * head_size], head blocks are
+  // contiguous of size head_size.
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -276,15 +292,15 @@ void batched_rotary_embedding(
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(),
               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::batched_rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(),
               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
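
For context on when query.stride(-2) can differ from head_size, here is a hedged caller-side sketch (libtorch, made-up shapes, not taken from vLLM's Python layer): a query view carved out of a wider per-head buffer keeps head_size columns but a larger heads dim stride, which is exactly the value the kernels now receive through the new head_stride argument.

#include <torch/torch.h>
#include <cstdio>

int main() {
  const int64_t num_tokens = 3, num_heads = 4, head_size = 8;

  // A hypothetical buffer that stores 16 elements per head, of which only the
  // first head_size belong to the rotary query; slicing the last dim yields a
  // non-contiguous [num_tokens, num_heads, head_size] view.
  auto buf = torch::zeros({num_tokens, num_heads, 16});
  auto query = buf.slice(/*dim=*/2, /*start=*/0, /*end=*/head_size);

  // stride(0) is the token stride the kernels already used (query_stride);
  // stride(1) is 16, not head_size, so stepping between heads by head_size
  // would walk into the padding. head_stride carries the correct value.
  std::printf("query_stride=%lld heads stride=%lld head_size=%lld\n",
              (long long)query.stride(0), (long long)query.stride(1),
              (long long)head_size);
  return 0;
}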