
Commit e6f92bc

WoosukKwon authored and xuebwang-amd committed
[Chore] Remove unused batched RoPE op & kernel (vllm-project#24789)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent edda199 commit e6f92bc
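For context, the op that survives this change is the plain `rotary_embedding` custom op; the removed `batched_rotary_embedding` differed only in taking an explicit `rot_dim` plus a per-token `cos_sin_cache_offsets` tensor so that tokens using different LoRA linear-scaling factors could index disjoint slices of one packed cache. Below is a minimal sketch of the surviving call path, not part of this commit: the shapes and the random placeholder `cos_sin_cache` are illustrative only, and it assumes vLLM is installed with its CUDA extension.

```python
import torch
from vllm import _custom_ops as ops  # loads the compiled _C extension

# Hypothetical shapes, for illustration only.
num_tokens, num_heads, head_size = 8, 4, 64
max_position = 8192

positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
# Placeholder cache; in vLLM the RotaryEmbedding layer builds the real
# [max_position, rot_dim] table of cos values followed by sin values.
cos_sin_cache = torch.randn(max_position, head_size,
                            dtype=torch.float16, device="cuda")

# Retained op: rotates query/key in place, indexing the cache by position.
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                     True)  # True = NeoX-style rotation
```

The removed variant was called the same way plus `rot_dim` and `cos_sin_cache_offsets`; per the commit title, nothing still used that path.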

File tree

8 files changed: +16, -348 lines


csrc/ops.h

Lines changed: 0 additions & 6 deletions
@@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                        std::optional<torch::Tensor> key, int64_t head_size,
                        torch::Tensor& cos_sin_cache, bool is_neox);
 
-void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
-                              std::optional<torch::Tensor> key,
-                              int64_t head_size, torch::Tensor& cos_sin_cache,
-                              bool is_neox, int64_t rot_dim,
-                              torch::Tensor& cos_sin_cache_offsets);
-
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,

csrc/pos_encoding_kernels.cu

Lines changed: 0 additions & 122 deletions
@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
       token_idx, query_stride, key_stride, head_stride);
 }
 
-template <typename scalar_t, bool IS_NEOX>
-__global__ void batched_rotary_embedding_kernel(
-    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
-                                            // [num_tokens]
-    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
-                                   // head_size] or [num_tokens, num_heads,
-                                   // head_size]
-    scalar_t* __restrict__ key,    // nullptr or
-                                   // [batch_size, seq_len, num_kv_heads,
-                                   // head_size] or [num_tokens, num_kv_heads,
-                                   // head_size]
-    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
-                                                 // 2]
-    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int64_t head_stride, const int num_heads, const int num_kv_heads,
-    const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
-  const scalar_t* cache_ptr =
-      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
-
-  apply_rotary_embedding<scalar_t, IS_NEOX>(
-      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride, head_stride);
-}
-
 }  // namespace vllm
 
 void rotary_embedding(
@@ -211,96 +182,3 @@ void rotary_embedding(
     }
   });
 }
-
-/*
-Batched version of rotary embedding, pack multiple LoRAs together
-and process in batched manner.
-*/
-void batched_rotary_embedding(
-    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size] or
-                           // [batch_size, seq_len, num_heads, head_size] or
-                           // [num_tokens, num_heads, head_size]
-    std::optional<torch::Tensor>
-        key,  // null or
-              // [batch_size, seq_len, num_kv_heads * head_size] or
-              // [num_tokens, num_kv_heads * head_size] or
-              // [batch_size, seq_len, num_heads, head_size] or
-              // [num_tokens, num_heads, head_size]
-    int64_t head_size,
-    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-    bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
-) {
-  // num_tokens = batch_size * seq_len
-  int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  TORCH_CHECK(
-      positions.size(0) == num_tokens || positions.numel() == num_tokens,
-      "positions must have the same num_tokens or batch_size as "
-      "cos_sin_cache_offsets");
-
-  int positions_ndim = positions.dim();
-  // Make sure num_tokens dim is consistent across positions, query, and key
-  TORCH_CHECK(
-      positions_ndim == 1 || positions_ndim == 2,
-      "positions must have shape [num_tokens] or [batch_size, seq_len]");
-  if (positions_ndim == 1) {
-    TORCH_CHECK(query.size(0) == positions.size(0) &&
-                    (!key.has_value() || key->size(0) == positions.size(0)),
-                "query, key and positions must have the same number of tokens");
-  }
-  if (positions_ndim == 2) {
-    TORCH_CHECK(
-        query.size(0) == positions.size(0) &&
-            (!key.has_value() || key->size(0) == positions.size(0)) &&
-            query.size(1) == positions.size(1) &&
-            (!key.has_value() || key->size(1) == positions.size(1)),
-        "query, key and positions must have the same batch_size and seq_len");
-  }
-
-  // Make sure head_size is valid for query and key
-  int query_hidden_size = query.numel() / num_tokens;
-  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
-  TORCH_CHECK(query_hidden_size % head_size == 0);
-  TORCH_CHECK(key_hidden_size % head_size == 0);
-
-  // Make sure query and key have concistent number of heads
-  int num_heads = query_hidden_size / head_size;
-  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
-  TORCH_CHECK(num_heads % num_kv_heads == 0);
-
-  int seq_dim_idx = positions_ndim - 1;
-  int64_t query_stride = query.stride(seq_dim_idx);
-  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
-  // Determine head stride: for [*, heads, head_size] use stride of last dim;
-  // for flat [*, heads*head_size], heads blocks are contiguous of size
-  // head_size
-  int query_ndim = query.dim();
-  int64_t head_stride =
-      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
-    if (is_neox) {
-      vllm::batched_rotary_embedding_kernel<scalar_t, true>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    } else {
-      vllm::batched_rotary_embedding_kernel<scalar_t, false>
-          <<<grid, block, 0, stream>>>(
-              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
-              cos_sin_cache.data_ptr<scalar_t>(),
-              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, head_stride, num_heads, num_kv_heads, head_size);
-    }
-  });
-}
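The deleted kernel was a near-copy of the retained `rotary_embedding_kernel`: the only difference was indexing the cache row at `(cos_sin_cache_offset + pos) * rot_dim` instead of `pos * rot_dim` before handing off to the shared `apply_rotary_embedding` helper. For reference, here is a minimal pure-PyTorch sketch (not from this repository) of the NeoX-style rotation that helper applies per (token, head), under the simplifying assumption that `rot_dim == head_size`:

```python
import torch

def neox_rotate(x: torch.Tensor, cos: torch.Tensor,
                sin: torch.Tensor) -> torch.Tensor:
    # NeoX layout: split the head into two halves and rotate each pair
    # (x1, x2) -> (x1 * cos - x2 * sin, x2 * cos + x1 * sin).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

head_size, pos, base = 64, 42, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2,
                                        dtype=torch.float32) / head_size))
angles = pos * inv_freq              # [head_size // 2]
q_head = torch.randn(head_size)
q_rotated = neox_rotate(q_head, angles.cos(), angles.sin())
```

The GPT-J (non-NeoX) layout interleaves the pairs instead of splitting the head into halves; the CUDA kernels select between the two with the `IS_NEOX` template parameter.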

csrc/torch_bindings.cpp

Lines changed: 0 additions & 10 deletions
@@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
 
-  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
-  // (supports multiple loras).
-  ops.def(
-      "batched_rotary_embedding(Tensor positions, Tensor! query,"
-      " Tensor!? key, int head_size,"
-      " Tensor cos_sin_cache, bool is_neox,"
-      " int rot_dim,"
-      " Tensor cos_sin_cache_offsets) -> ()");
-  ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
-
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AWQ.
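With the schema and impl unregistered, `torch.ops._C.batched_rotary_embedding` no longer resolves. Out-of-tree code is not expected to depend on it (the commit title calls the op unused), but probing for it is cheap. A hedged sketch, assuming the `_C` extension has already been loaded (for example by importing `vllm._custom_ops`):

```python
import torch
import vllm._custom_ops  # noqa: F401  (loads the compiled _C extension)

# After this commit the batched schema is gone, so the attribute lookup on
# the op namespace fails and hasattr returns False.
has_batched_rope = hasattr(torch.ops._C, "batched_rotary_embedding")
print("batched_rotary_embedding available:", has_batched_rope)
```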

tests/kernels/core/test_pos_encoding.py

Lines changed: 1 addition & 146 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from itertools import accumulate, product
+from itertools import product
 from typing import Callable, Optional
 
 import pytest
@@ -111,151 +111,6 @@ def test_rotary_embedding(
             "expected returned key to be None"
 
 
-@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("seq_len", SEQ_LENS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_key", USE_KEY)
-@torch.inference_mode()
-def test_batched_rotary_embedding(
-    is_neox_style: bool,
-    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: Optional[int],
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    use_key: bool,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": (1, )
-    })
-    rope = rope.to(dtype=dtype, device=torch.get_default_device())
-
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
-    query = torch.randn(query_shape, dtype=dtype)
-    key = torch.randn_like(query) if use_key else None
-
-    # slice tensor if required, noop otherwise
-    query = query[..., :head_size]
-    key = key[..., :head_size] if use_key else None
-
-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_query, ref_key = rope.forward_native(positions, query, key)
-    out_query, out_key = rope.forward(positions,
-                                      query,
-                                      key,
-                                      offsets=torch.zeros(batch_size * seq_len,
-                                                          dtype=torch.long,
-                                                          device=device))
-    # Compare the results.
-    torch.testing.assert_close(out_query,
-                               ref_query,
-                               atol=get_default_atol(out_query),
-                               rtol=get_default_rtol(out_query))
-    if use_key:
-        torch.testing.assert_close(out_key,
-                                   ref_key,
-                                   atol=get_default_atol(out_key),
-                                   rtol=get_default_rtol(out_key))
-    else:
-        assert ref_key is None and out_key is None, \
-            "expected returned key to be None"
-
-
-@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("seq_len", SEQ_LENS)
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_key", USE_KEY)
-@torch.inference_mode()
-def test_batched_rotary_embedding_multi_lora(
-    is_neox_style: bool,
-    batch_size: int,
-    seq_len: int,
-    num_heads: int,
-    head_size: int,
-    rotary_dim: Optional[int],
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-    use_key: bool,
-    max_position: int = 8192,
-    base: float = 10000,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    if rotary_dim is None:
-        rotary_dim = head_size
-    scaling_factors: list[int] = [1, 2, 4]
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
-        "rope_type": "linear",
-        "factor": tuple(scaling_factors)
-    })
-    rope = rope.to(dtype=dtype, device=torch.get_default_device())
-
-    positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
-    key = torch.randn_like(query) if use_key else None
-
-    offset_map = torch.tensor(
-        list(
-            accumulate([0] + [
-                max_position * scaling_factor * 2
-                for scaling_factor in scaling_factors[:-1]
-            ])))
-    query_types = torch.randint(0,
-                                len(scaling_factors), (batch_size, seq_len),
-                                device=device)
-    query_offsets = offset_map[query_types]
-
-    # NOTE(woosuk): The reference implementation should be executed first
-    # because the custom kernel is in-place.
-    ref_query, ref_key = rope.forward_native(positions, query, key,
-                                             query_offsets)
-    out_query, out_key = rope.forward(positions, query, key,
-                                      query_offsets.flatten())
-    # Compare the results.
-    torch.testing.assert_close(out_query,
-                               ref_query,
-                               atol=get_default_atol(out_query),
-                               rtol=get_default_rtol(out_query))
-    if use_key:
-        torch.testing.assert_close(out_key,
-                                   ref_key,
-                                   atol=get_default_atol(out_key),
-                                   rtol=get_default_rtol(out_key))
-    else:
-        assert ref_key is None and out_key is None, \
-            "expected returned key to be None"
-
-
 @torch.inference_mode()
 def test_rope_module_cache():
     MAX_POSITIONS = [123, 1234]
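The surviving `test_rotary_embedding` keeps the pattern the two deleted tests relied on: run the pure-PyTorch `forward_native` reference before the custom op, because the CUDA path rotates `query`/`key` in place and would otherwise clobber the inputs the reference needs. As a standalone illustration of that ordering (not part of the test suite; the tolerance values here are placeholders, not the ones the suite derives per dtype):

```python
import torch

def check_against_native(rope, positions, query, key=None):
    # Reference first: forward_native returns new tensors and leaves its
    # inputs untouched, so the in-place CUDA op still sees the original data.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    torch.testing.assert_close(out_query, ref_query, atol=1e-2, rtol=1e-2)
    if key is None:
        assert ref_key is None and out_key is None
    else:
        torch.testing.assert_close(out_key, ref_key, atol=1e-2, rtol=1e-2)
```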

tests/kernels/core/test_rotary_embedding.py

Lines changed: 6 additions & 16 deletions
@@ -16,20 +16,14 @@
 def rotary_embedding_opcheck(rot,
                              positions: torch.Tensor,
                              query: torch.Tensor,
-                             key: Optional[torch.Tensor] = None,
-                             offsets: Optional[torch.Tensor] = None):
+                             key: Optional[torch.Tensor] = None):
     cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
 
-    # ops.rotary_embedding()/batched_rotary_embedding()
-    # are in-place operations that update the query and key tensors.
-    if offsets is not None:
-        opcheck(torch.ops._C.batched_rotary_embedding,
-                (positions, query, key, rot.head_size, cos_sin_cache,
-                 rot.is_neox_style, rot.rotary_dim, offsets))
-    else:
-        opcheck(torch.ops._C.rotary_embedding,
-                (positions, query, key, rot.head_size, cos_sin_cache,
-                 rot.is_neox_style))
+    # ops.rotary_embedding() is a in-place operation
+    # that updates the query and key tensors.
+    opcheck(torch.ops._C.rotary_embedding,
+            (positions, query, key, rot.head_size, cos_sin_cache,
+             rot.is_neox_style))
 
 
 @pytest.mark.parametrize("device", ["cuda"])
@@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     key = key[..., :head_size] if use_key else None
 
     rotary_embedding_opcheck(rot, positions, query, key)
-    offsets = torch.zeros(batch_size * seq_len,
-                          device=device,
-                          dtype=torch.long)
-    rotary_embedding_opcheck(rot, positions, query, key, offsets)
 
     # if we have a contiguous head stride, test the alternate
     # [..., num_heads * head_dim] shape/layout
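After this change, `rotary_embedding_opcheck` exercises only the one remaining schema. For readers unfamiliar with `opcheck`, here is a minimal sketch of checking the retained op directly via `torch.library.opcheck`; shapes, dtypes, and the random cache are illustrative, and the actual test suite uses its own `opcheck` wrapper plus much broader parametrization.

```python
import torch
from torch.library import opcheck
import vllm._custom_ops  # noqa: F401  (registers the torch.ops._C ops)

num_tokens, num_heads, head_size = 4, 2, 64
positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
cos_sin_cache = torch.randn(8192, head_size, dtype=torch.float16,
                            device="cuda")

# Validates the schema and fake-tensor behavior of the remaining RoPE op.
opcheck(torch.ops._C.rotary_embedding,
        (positions, query, key, head_size, cos_sin_cache, True))
```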

vllm/_custom_ops.py

Lines changed: 0 additions & 10 deletions
@@ -257,16 +257,6 @@ def rotary_embedding(
                                   cos_sin_cache, is_neox)
 
 
-def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
-                             key: Optional[torch.Tensor], head_size: int,
-                             cos_sin_cache: torch.Tensor, is_neox: bool,
-                             rot_dim: int,
-                             cos_sin_cache_offsets: torch.Tensor) -> None:
-    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
-                                          cos_sin_cache, is_neox, rot_dim,
-                                          cos_sin_cache_offsets)
-
-
 # layer norm ops
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              epsilon: float) -> None:

vllm/_ipex_ops.py

Lines changed: 0 additions & 11 deletions
@@ -148,17 +148,6 @@ def rotary_embedding(
                                              head_size, cos_sin_cache,
                                              is_neox, rot_dim)
 
-    @staticmethod
-    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
-                                 key: torch.Tensor, head_size: int,
-                                 cos_sin_cache: torch.Tensor, is_neox: bool,
-                                 rot_dim: int,
-                                 cos_sin_cache_offsets: torch.Tensor) -> None:
-        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
-                                                     head_size, cos_sin_cache,
-                                                     is_neox, rot_dim,
-                                                     cos_sin_cache_offsets)
-
     @staticmethod
     def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                  epsilon: float) -> torch.Tensor:
