6 changes: 0 additions & 6 deletions csrc/ops.h
@@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key,
int64_t head_size, torch::Tensor& cos_sin_cache,
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
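For context on what the deleted declaration exposed, the surviving and removed operators differ only in the trailing `rot_dim` and `cos_sin_cache_offsets` arguments. Below is a minimal sketch of the two call sites at the Python level, assuming a CUDA build of vLLM's `_C` extension; all tensors are illustrative placeholders, not a real RoPE cache.

```python
import torch
import vllm._custom_ops  # noqa: F401  (assumed to register the _C ops as an import side effect)

num_tokens, num_heads, head_size, rot_dim, max_position = 8, 4, 64, 64, 8192
positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
key = None  # key is optional in both schemas
cos_sin_cache = torch.randn(max_position, rot_dim, device="cuda", dtype=torch.float16)

# Still available after this PR (rotates query/key in place):
torch.ops._C.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)

# Removed by this PR; the extra arguments were rot_dim and a per-token
# cos_sin_cache_offsets tensor selecting a cache region (e.g. per LoRA):
# torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
#                                       cos_sin_cache, True, rot_dim,
#                                       torch.zeros(num_tokens, dtype=torch.long,
#                                                   device="cuda"))
```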
122 changes: 0 additions & 122 deletions csrc/pos_encoding_kernels.cu
@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
token_idx, query_stride, key_stride, head_stride);
}

template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride, head_stride);
}

} // namespace vllm

void rotary_embedding(
@@ -211,96 +182,3 @@ void rotary_embedding(
}
});
}

/*
Batched version of rotary embedding; packs multiple LoRAs together
and processes them in a batched manner.
*/
void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std::optional<torch::Tensor>
key, // null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = cos_sin_cache_offsets.size(0);
TORCH_CHECK(
positions.size(0) == num_tokens || positions.numel() == num_tokens,
"positions must have the same num_tokens or batch_size as "
"cos_sin_cache_offsets");

int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) &&
(!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len");
}

// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);

// Make sure query and key have a consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0);

int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
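The removed kernel differs from the surviving `rotary_embedding_kernel` only in how it picks the cache row: it reads `cos_sin_cache + (cos_sin_cache_offsets[token] + positions[token]) * rot_dim` instead of indexing by the position alone. The pure-PyTorch sketch below illustrates that indexing relationship; it is not vLLM code, and it assumes the surviving kernel reads row `positions[token]`, which this diff does not show.

```python
import torch

max_position, rot_dim, num_tokens = 8192, 64, 16
cos_sin_cache = torch.randn(max_position, rot_dim)      # [max_position, rot_dim]
positions = torch.randint(0, 1024, (num_tokens,))
offsets = torch.randint(0, 3, (num_tokens,)) * 2048     # illustrative per-token cache offsets

token = 3
# Pointer arithmetic from the removed kernel, reproduced on a flat view:
#   cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim
flat = cos_sin_cache.reshape(-1)
start = int((offsets[token] + positions[token]) * rot_dim)
row_from_ptr = flat[start:start + rot_dim]

# ...which is simply row (pos + offset) of the 2-D cache, i.e. the lookup a
# position shifted by the offset would produce:
row_from_shifted_position = cos_sin_cache[int(positions[token] + offsets[token])]
assert torch.equal(row_from_ptr, row_from_shifted_position)

# The row then splits into cos and sin halves of size rot_dim // 2,
# matching the [max_position, 2, rot_dim // 2] layout noted in the kernel comment.
cos, sin = row_from_ptr.chunk(2, dim=-1)
```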
10 changes: 0 additions & 10 deletions csrc/torch_bindings.cpp
@@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);

// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AWQ.
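With the schema registration gone, code that still references the op fails at attribute lookup on `torch.ops._C` rather than at dispatch time. The sketch below is a hedged version guard; the attribute probe and the import side effect are assumptions about PyTorch's `torch.ops` namespace and vLLM's extension loading, not documented API.

```python
import torch
import vllm._custom_ops  # noqa: F401  (assumed to load and register the _C extension)


def has_batched_rotary_embedding() -> bool:
    """Return True if this build still registers the removed custom op."""
    try:
        _ = getattr(torch.ops._C, "batched_rotary_embedding")
        return True
    except (AttributeError, RuntimeError):
        return False
```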
147 changes: 1 addition & 146 deletions tests/kernels/core/test_pos_encoding.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from itertools import accumulate, product
from itertools import product
from typing import Callable, Optional

import pytest
@@ -111,151 +111,6 @@ def test_rotary_embedding(
"expected returned key to be None"


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear",
"factor": (1, )
})
rope = rope.to(dtype=dtype, device=torch.get_default_device())

positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None

# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions,
query,
key,
offsets=torch.zeros(batch_size * seq_len,
dtype=torch.long,
device=device))
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear",
"factor": tuple(scaling_factors)
})
rope = rope.to(dtype=dtype, device=torch.get_default_device())

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query) if use_key else None

offset_map = torch.tensor(
list(
accumulate([0] + [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
])))
query_types = torch.randint(0,
len(scaling_factors), (batch_size, seq_len),
device=device)
query_offsets = offset_map[query_types]

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"


@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
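The deleted multi-LoRA test derived per-token cache offsets from a running sum over the scaling factors, sizing each region as `max_position * factor * 2` cache rows. A small standalone sketch of that `offset_map` arithmetic with the defaults used above (`max_position=8192`, factors `[1, 2, 4]`):

```python
from itertools import accumulate

import torch

max_position = 8192
scaling_factors = [1, 2, 4]

# Each earlier region contributes max_position * factor * 2 entries, as in the
# deleted test, so region i starts at the running sum over regions 0..i-1.
offset_map = torch.tensor(
    list(accumulate([0] + [max_position * f * 2 for f in scaling_factors[:-1]])))
print(offset_map.tolist())  # [0, 16384, 49152]

# Each (batch, seq) token picked one scaling factor; its cache offset was
# offset_map at that factor's index.
query_types = torch.randint(0, len(scaling_factors), (2, 8))
query_offsets = offset_map[query_types]  # [2, 8] per-token cache offsets
```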
22 changes: 6 additions & 16 deletions tests/kernels/core/test_rotary_embedding.py
@@ -16,20 +16,14 @@
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None):
key: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)

# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
opcheck(torch.ops._C.batched_rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style, rot.rotary_dim, offsets))
else:
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
opcheck(torch.ops._C.rotary_embedding,
(positions, query, key, rot.head_size, cos_sin_cache,
rot.is_neox_style))


@pytest.mark.parametrize("device", ["cuda"])
@@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
key = key[..., :head_size] if use_key else None

rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)

# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
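The surviving test still exercises both query layouts the kernel accepts, as the trailing comment above notes. A minimal sketch of how the flat `[..., num_heads * head_dim]` layout relates to the `[..., num_heads, head_size]` layout; the shapes are illustrative:

```python
import torch

batch_size, seq_len, num_heads, head_size = 2, 8, 4, 64

# Layout with an explicit head dimension: [..., num_heads, head_size]
q4d = torch.randn(batch_size, seq_len, num_heads, head_size)

# Flat layout: [..., num_heads * head_size]. For a contiguous tensor the heads
# stay contiguous blocks of head_size elements, which is the "contiguous head
# stride" case mentioned in the test comment above.
q3d = q4d.flatten(-2, -1)
assert q3d.shape == (batch_size, seq_len, num_heads * head_size)
assert q4d.stride(-2) == head_size  # head stride equals head_size when contiguous
```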
10 changes: 0 additions & 10 deletions vllm/_custom_ops.py
@@ -257,16 +257,6 @@ def rotary_embedding(
cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor], head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
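Callers of the deleted wrapper can typically migrate to the remaining `rotary_embedding` wrapper by folding the per-token offsets into the positions, since the removed kernel used the offset only to shift the cache row. This is a hedged sketch under the assumptions that the non-batched kernel indexes the cache at `positions[token]` (not shown in this diff), that there is one offset per token, and that `positions + offsets` stays within `cos_sin_cache`; `rope_with_offsets` is a hypothetical helper, not vLLM API.

```python
from typing import Optional

import torch

from vllm import _custom_ops as ops


def rope_with_offsets(positions: torch.Tensor, query: torch.Tensor,
                      key: Optional[torch.Tensor], head_size: int,
                      cos_sin_cache: torch.Tensor, is_neox: bool,
                      offsets: torch.Tensor) -> None:
    """Hypothetical stand-in for the removed ops.batched_rotary_embedding.

    Folds the per-token cache offsets into the positions and calls the
    remaining in-place op; query (and key, if given) are rotated in place.
    """
    shifted = positions + offsets.view(positions.shape)
    ops.rotary_embedding(shifted, query, key, head_size, cos_sin_cache, is_neox)
```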
11 changes: 0 additions & 11 deletions vllm/_ipex_ops.py
@@ -148,17 +148,6 @@ def rotary_embedding(
head_size, cos_sin_cache,
is_neox, rot_dim)

@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
head_size, cos_sin_cache,
is_neox, rot_dim,
cos_sin_cache_offsets)

@staticmethod
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> torch.Tensor: