Commit d93c976

[Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 749f792 commit d93c976
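
Before this change, the Python wrappers in vllm/_custom_ops.py forced query and key to be contiguous and copied the results back (see the last diff below), because the kernels assumed heads packed back-to-back at head_size intervals. This commit threads an explicit head_stride through the CUDA kernels so that sliced, non-contiguous views can be rotated in place. A minimal sketch of the layout this enables, in plain PyTorch with shapes borrowed from the new tests:

    import torch

    # A query allocated as [batch, seq, heads, head_size + pad] and then
    # sliced: the result is a non-contiguous view whose head stride
    # exceeds head_size.
    head_size = 32
    query = torch.randn(1, 11, 7, head_size + 64)[..., :head_size]

    assert not query.is_contiguous()
    assert query.stride(-2) == head_size + 64  # the head_stride the kernel now uses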

File tree

4 files changed (+59, -31 lines changed):
  csrc/pos_encoding_kernels.cu
  tests/kernels/core/test_pos_encoding.py
  tests/kernels/core/test_rotary_embedding.py
  vllm/_custom_ops.py


csrc/pos_encoding_kernels.cu

Lines changed: 28 additions & 12 deletions
@@ -44,15 +44,17 @@ inline __device__ void apply_rotary_embedding(
                                  // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
     const int num_kv_heads, const int rot_dim, const int token_idx,
-    const int64_t query_stride, const int64_t key_stride) {
+    const int64_t query_stride, const int64_t key_stride,
+    const int64_t head_stride) {
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
 
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int64_t token_head =
+        token_idx * query_stride + head_idx * head_stride;
     const int rot_offset = i % embed_dim;
     apply_token_rotary_embedding<scalar_t, IS_NEOX>(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int64_t token_head =
+        token_idx * key_stride + head_idx * head_stride;
     const int rot_offset = i % embed_dim;
     apply_token_rotary_embedding<scalar_t, IS_NEOX>(
         key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -84,15 +87,16 @@ __global__ void rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 template <typename scalar_t, bool IS_NEOX>
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-                                                        // or [num_tokens]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 }  // namespace vllm
@@ -179,6 +183,12 @@ void rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use stride of last dim;
+  // for flat [*, heads*head_size], heads blocks are contiguous of size
+  // head_size
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -190,14 +200,14 @@ void rotary_embedding(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
           key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
           cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
-          num_heads, num_kv_heads, head_size);
+          head_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
-             key_stride, num_heads, num_kv_heads, head_size);
+             key_stride, head_stride, num_heads, num_kv_heads, head_size);
    }
  });
}
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use stride of last dim;
+  // for flat [*, heads*head_size], heads blocks are contiguous of size
+  // head_size
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -276,15 +292,15 @@ void batched_rotary_embedding(
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-             key_stride, num_heads, num_kv_heads, head_size);
+             key_stride, head_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-             key_stride, num_heads, num_kv_heads, head_size);
+             key_stride, head_stride, num_heads, num_kv_heads, head_size);
    }
  });
}
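
The head_stride inference added above distinguishes the two accepted query layouts purely by rank relative to positions. The same rule, mirrored in PyTorch for illustration (the kernel host code does this in C++; the shapes here are examples):

    import torch

    positions = torch.zeros(1, 11, dtype=torch.long)  # positions_ndim == 2
    head_size = 32

    # [batch, seq, heads, head_size]: query has two more dims than positions,
    # so the head stride is read off the tensor (here padded, hence 96).
    q4 = torch.randn(1, 11, 7, head_size + 64)[..., :head_size]
    assert q4.dim() == positions.dim() + 2 and q4.stride(-2) == head_size + 64

    # Flat [batch, seq, heads * head_size]: one extra dim, heads are packed,
    # so the head stride defaults to head_size.
    q3 = torch.randn(1, 11, 7 * head_size)
    assert q3.dim() == positions.dim() + 1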

tests/kernels/core/test_pos_encoding.py

Lines changed: 13 additions & 1 deletion
@@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
     return (batch_size, seq_len, num_heads * head_size)
 
 
+# For testing sliced tensors
+def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                             head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size + 64)
+
+
 def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
                             head_size: int) -> tuple[int, ...]:
     return (batch_size, seq_len, num_heads, head_size)
 
 
-TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+TENSORS_SHAPES_FN = [
+    _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
+]
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -79,6 +87,10 @@ def test_rotary_embedding(
     query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query) if use_key else None
 
+    # slice tensor if required, noop otherwise
+    query = query[..., :head_size]
+    key = key[..., :head_size] if use_key else None
+
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
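
The new _get_padded_tensor_shape gives each head 64 trailing elements of slack, so the slice added in test_rotary_embedding produces a genuinely non-contiguous view; for the existing [batch, seq, heads, head_size] layout the same slice is an identity. Illustrated in isolation (a sketch; the constants are the test's):

    import torch

    batch, seq, heads, head_size = 1, 11, 7, 32

    # _get_batch_tensor_shape layout: the slice keeps every element and
    # stays a contiguous view over the same storage.
    q = torch.randn(batch, seq, heads, head_size)
    assert q[..., :head_size].is_contiguous()
    assert q[..., :head_size].data_ptr() == q.data_ptr()

    # _get_padded_tensor_shape layout: the slice drops each head's
    # 64-element tail and leaves a strided view for the kernel to handle.
    qp = torch.randn(batch, seq, heads, head_size + 64)[..., :head_size]
    assert not qp.is_contiguous()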

tests/kernels/core/test_rotary_embedding.py

Lines changed: 15 additions & 2 deletions
@@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
 @pytest.mark.parametrize("head_size", [32, 108])
 @pytest.mark.parametrize("seq_len", [11, 1024])
 @pytest.mark.parametrize("use_key", [True, False])
+@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
 def test_rotary_embedding_opcheck(dist_init, device, max_position,
                                   is_neox_style, rotary_dim, head_size,
-                                  seq_len, use_key):
+                                  seq_len, use_key, head_stride_is_contiguous):
     batch_size = 1
     base = 10000
     num_heads = 7
@@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),
                               device=device)
+    head_stride = head_size + (0 if head_stride_is_contiguous else 64)
+
     query = torch.randn(batch_size,
                         seq_len,
-                        num_heads * head_size,
+                        num_heads,
+                        head_stride,
                         dtype=torch.float32,
                         device=device)
     key = torch.randn_like(query) if use_key else None
+    query = query[..., :head_size]
+    key = key[..., :head_size] if use_key else None
 
     rotary_embedding_opcheck(rot, positions, query, key)
     offsets = torch.zeros(batch_size * seq_len,
                           device=device,
                           dtype=torch.long)
     rotary_embedding_opcheck(rot, positions, query, key, offsets)
+
+    # if we have a contiguous head stride, test the alternate
+    # [..., num_heads * head_dim] shape/layout
+    if head_stride_is_contiguous:
+        rotary_embedding_opcheck(
+            rot, positions, query.flatten(start_dim=-2),
+            key.flatten(start_dim=-2) if use_key else None)
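
The final opcheck above exercises the flat [..., num_heads * head_size] layout. flatten(start_dim=-2) only preserves the op's in-place semantics when it can return a view, i.e., when the head stride is contiguous; on the padded slice it would silently copy, which is why the flat case is guarded. A quick demonstration of that PyTorch behavior:

    import torch

    head_size = 32

    # Contiguous head stride: flatten is a pure view, so in-place kernel
    # writes land in the original tensor.
    q = torch.randn(1, 11, 7, head_size)
    assert q.flatten(start_dim=-2).data_ptr() == q.data_ptr()

    # Padded and sliced: the trailing dims cannot be merged into a view,
    # so flatten falls back to a copy and in-place writes would be lost.
    qp = torch.randn(1, 11, 7, head_size + 64)[..., :head_size]
    assert qp.flatten(start_dim=-2).data_ptr() != qp.data_ptr()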

vllm/_custom_ops.py

Lines changed: 3 additions & 16 deletions
@@ -254,31 +254,18 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
-    query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous() if key is not None else None
-    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
-                                  head_size, cos_sin_cache, is_neox)
-    query.copy_(query_contiguous)
-    if key is not None:
-        key.copy_(key_contiguous)
+    torch.ops._C.rotary_embedding(positions, query, key, head_size,
+                                  cos_sin_cache, is_neox)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              key: Optional[torch.Tensor], head_size: int,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
-    query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous() if key is not None else None
-    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
-                                          key_contiguous, head_size,
+    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                           cos_sin_cache, is_neox, rot_dim,
                                           cos_sin_cache_offsets)
-    query.copy_(query_contiguous)
-    if key is not None:
-        key.copy_(key_contiguous)
 
 
 # layer norm ops
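
With the kernels now stride-aware, both wrappers pass query and key straight through, and the custom op mutates the caller's possibly-sliced tensors in place; the old contiguous()/copy_() round trip, two full extra copies per call, is gone. A hedged usage sketch (assumes a CUDA build of vLLM with the _C extension; the cache contents below are random placeholders, not a real cos/sin table):

    import torch
    from vllm import _custom_ops as ops

    head_size, max_position = 32, 4096
    # Non-contiguous slices, rotated in place with no contiguous() round trip.
    query = torch.randn(1, 11, 7, head_size + 64, device="cuda")[..., :head_size]
    key = torch.randn(1, 11, 7, head_size + 64, device="cuda")[..., :head_size]
    positions = torch.randint(max_position, (1, 11), device="cuda")
    cos_sin_cache = torch.randn(max_position, head_size, device="cuda")

    ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                         is_neox=True)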
