Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: non-contiguous query with paged kv cache #553

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions flashinfer-aot/csrc_aot/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get q_stride_n and q_stride_h
const auto q_stride_n = q.stride(0);
const auto q_stride_h = q.stride(1);

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
Expand Down Expand Up @@ -157,8 +161,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
/*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
/*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
sm_scale, rope_scale, rope_theta);
/*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);

DTypeO* tmp_v = nullptr;
float* tmp_s = nullptr;
Expand Down
12 changes: 7 additions & 5 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get q_stride_n and q_stride_h
const auto q_stride_n = q.stride(0);
const auto q_stride_h = q.stride(1);

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
Expand All @@ -254,8 +258,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
LinHeLurking marked this conversation as resolved.
Show resolved Hide resolved
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
Expand All @@ -266,7 +269,6 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
get_variant_code(/*use_custom_mask=*/MASK_MODE == MaskMode::kCustom,
/*use_sliding_window=*/true, USE_LOGITS_SOFT_CAP,
/*use_alibi_slopes=*/false)>;

PagedParamsT params(
static_cast<DTypeQ*>(q.data_ptr()), paged_kv,
maybe_custom_mask.has_value() ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr())
Expand All @@ -276,8 +278,8 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
: nullptr,
/*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
/*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap, sm_scale,
rope_scale, rope_theta);
/*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
logits_soft_cap, sm_scale, rope_scale, rope_theta);

DTypeO* tmp_v = nullptr;
float* tmp_s = nullptr;
Expand Down
6 changes: 4 additions & 2 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
vec_t<float, vec_size> q_vec;
vec_t<float, vec_size> freq;
int32_t q_offset_val = q_offset == nullptr ? (kv_len - 1) : q_offset[batch_idx];
const uint32_t q_stride_n = params.q_stride_n;
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
const uint32_t q_stride_h = params.q_stride_h;
if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
const float rope_rcp_scale = params.rope_rcp_scale;
const float rope_rcp_theta = params.rope_rcp_theta;
Expand All @@ -450,10 +452,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
}
// apply rotary embedding to q matrix
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val);
q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, q_offset_val);
} else {
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size);
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
Expand Down
9 changes: 7 additions & 2 deletions include/flashinfer/attention/decode_params.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ struct BatchDecodeParams {
float* alibi_slopes;
uint32_t padded_batch_size;
uint32_t num_qo_heads;
IdType q_stride_n;
IdType q_stride_h;
int32_t window_left;
float logits_soft_cap;
float sm_scale;
Expand All @@ -135,8 +137,9 @@ struct BatchDecodeParams {
__device__ __host__ BatchDecodeParams(DTypeQ* q, IdType* q_offset,
paged_kv_t<DTypeKV, IdType> paged_kv, DTypeO* o, float* lse,
float* alibi_slopes, uint32_t num_qo_heads,
int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta)
IdType q_stride_n, IdType q_stride_h, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta)
: q(q),
q_offset(q_offset),
paged_kv(paged_kv),
Expand All @@ -145,6 +148,8 @@ struct BatchDecodeParams {
alibi_slopes(alibi_slopes),
padded_batch_size(0),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),
q_stride_h(q_stride_h),
window_left(window_left),
logits_soft_cap(logits_soft_cap),
sm_scale(sm_scale),
Expand Down
2 changes: 1 addition & 1 deletion include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1867,7 +1867,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
const uint32_t qo_packed_idx_base =
(qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<NUM_WARPS_Q, NUM_WARPS_KV>()) * NUM_FRAGS_Q *
16;
const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim;
const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h;
constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B;
smem_t<swizzle_mode_q> qo_smem(smem);
DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size,
Expand Down
10 changes: 7 additions & 3 deletions include/flashinfer/attention/prefill_params.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ struct BatchPrefillPagedParams {
float* lse;
float* alibi_slopes;
uint32_t num_qo_heads;
IdType q_stride_n;
IdType q_stride_h;
int32_t window_left;
float logits_soft_cap;
float sm_scale;
Expand All @@ -232,9 +234,9 @@ struct BatchPrefillPagedParams {
__host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t<DTypeKV, IdType> paged_kv,
uint8_t* custom_mask, IdType* q_indptr, IdType* qk_indptr,
IdType* q_offset, DTypeO* o, float* lse, float* alibi_slopes,
uint32_t num_qo_heads, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta)
uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta)
: q(q),
paged_kv(paged_kv),
custom_mask(custom_mask),
Expand All @@ -245,6 +247,8 @@ struct BatchPrefillPagedParams {
lse(lse),
alibi_slopes(alibi_slopes),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),
q_stride_h(q_stride_h),
window_left(window_left),
logits_soft_cap(logits_soft_cap),
sm_scale(sm_scale),
Expand Down
5 changes: 4 additions & 1 deletion python/flashinfer/jit/batch_decode_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@

void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

const auto q_stride_n = q.stride(0);
const auto q_stride_h = q.stride(1);

const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
Expand All @@ -121,7 +124,7 @@
/*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
{% if use_alibi == "true" %}static_cast<float*>(alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

{{ dtype_o }}* tmp_v = nullptr;
float* tmp_s = nullptr;
Expand Down
5 changes: 4 additions & 1 deletion python/flashinfer/jit/batch_prefill_templ.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@

void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
void* int_buffer_ptr = static_cast<void*>(int_workspace_buffer.data_ptr());

const auto q_stride_n = q.stride(0);
const auto q_stride_h = q.stride(1);

const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
Expand All @@ -221,7 +224,7 @@
static_cast<{{ dtype_o }}*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
{% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

{{ dtype_o }}* tmp_v = nullptr;
float* tmp_s = nullptr;
Expand Down
77 changes: 77 additions & 0 deletions tests/test_non_contiguous_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
import pytest
import flashinfer


@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("page_size", [1, 5])
@pytest.mark.parametrize("seq_len", [1])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_batch_paged_decode_packed_input(
batch_size,
page_size,
seq_len,
num_kv_heads,
num_qo_heads,
head_dim,
):
if num_qo_heads % num_kv_heads != 0:
pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
nnz = batch_size * seq_len
num_pages_per_req = (seq_len + page_size - 1) // page_size
num_pages = batch_size * num_pages_per_req
last_page_len = (seq_len - 1) % page_size + 1
k_cache = torch.randn(
size=(num_pages, page_size, num_kv_heads, head_dim),
dtype=torch.float16,
device="cuda:0",
)
v_cache = torch.randn_like(k_cache)
paged_kv_cache = (k_cache, v_cache)
workspace_buffer = torch.empty(
(256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
)
paged_kv_indptr = torch.tensor(
[i * num_pages_per_req for i in range(batch_size + 1)],
dtype=torch.int32,
device="cuda:0",
)
paged_kv_indices = torch.tensor(
list(range(num_pages)), dtype=torch.int32, device="cuda:0"
)
paged_kv_last_page_len = torch.tensor(
[last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
)

wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
wrapper.plan(
indptr=paged_kv_indptr,
indices=paged_kv_indices,
last_page_len=paged_kv_last_page_len,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
)

qkv_packed = torch.randn(
size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
dtype=torch.float16,
device="cuda:0",
)
qkv_split_idx = (
num_qo_heads * head_dim,
num_kv_heads * head_dim,
num_kv_heads * head_dim,
)
q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
q = q.view(-1, num_qo_heads, head_dim)
o_packed = wrapper.run(q, paged_kv_cache)
o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
test_batch_paged_decode_packed_input(37, 127, 1, 4, 64, 128)
79 changes: 79 additions & 0 deletions tests/test_non_contiguous_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,85 @@ def test_batch_ragged_prefill_packed_input(
torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("page_size", [1, 5])
@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("causal", [True, False])
def test_batch_paged_prefill_packed_input(
batch_size,
page_size,
seq_len,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
):
if num_qo_heads % num_kv_heads != 0:
pytest.skip("num_qo_heads must be a multiple of num_kv_heads")

nnz = batch_size * seq_len
num_pages_per_req = (seq_len + page_size - 1) // page_size
num_pages = batch_size * num_pages_per_req
last_page_len = (seq_len - 1) % page_size + 1
k_cache = torch.randn(
size=(num_pages, page_size, num_kv_heads, head_dim),
dtype=torch.float16,
device="cuda:0",
)
v_cache = torch.randn_like(k_cache)
paged_kv_cache = (k_cache, v_cache)
workspace_buffer = torch.empty(
(256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
)
qo_indptr = torch.tensor(
[i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
)
paged_kv_indptr = torch.tensor(
[i * num_pages_per_req for i in range(batch_size + 1)],
dtype=torch.int32,
device="cuda:0",
)
paged_kv_indices = torch.tensor(
list(range(num_pages)), dtype=torch.int32, device="cuda:0"
)
paged_kv_last_page_len = torch.tensor(
[last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
wrapper.plan(
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
page_size=page_size,
causal=causal,
)

qkv_packed = torch.randn(
size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
dtype=torch.float16,
device="cuda:0",
)
qkv_split_idx = (
num_qo_heads * head_dim,
num_kv_heads * head_dim,
num_kv_heads * head_dim,
)
q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
# pretend that we have already appended k/v to paged_kv table
q = q.view(-1, num_qo_heads, head_dim)
o_packed = wrapper.run(q, paged_kv_cache)
o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
test_single_prefill_packed_input(127, 4, 4, 64, True)
test_batch_ragged_prefill_packed_input(37, 127, 4, 4, 64, True)
test_batch_paged_prefill_packed_input(37, 5, 127, 4, 4, 64, True)