Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RoPE position info in batch prefill/decode kernels #69

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions include/flashinfer/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ template <bool partition_kv, RotaryMode rotary_mode, uint32_t num_stages_smem,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeIn* __restrict__ q, IdType* __restrict__ q_rope_position,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
Expand All @@ -520,6 +521,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
: 0;
const uint32_t seq_len =
partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len;
const uint32_t mapped_batch_idx =
partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx;

extern __shared__ uint8_t smem[];
DTypeIn* k_smem = (DTypeIn*)smem;
Expand All @@ -541,23 +544,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}
// apply rotary embedding to q matrix
if constexpr (partition_kv) {
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim,
freq, seq_len - 1);
} else {
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, seq_len - 1);
}
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq,
q_rope_position == nullptr ? (seq_len - 1) : q_rope_position[mapped_batch_idx]);
} else {
// do not apply rotary embedding to q matrix
if constexpr (partition_kv) {
q_vec.cast_load(
q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim +
tx * vec_size);
} else {
q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
}
q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
}
block.sync();

Expand Down Expand Up @@ -627,7 +619,9 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
block.sync();
compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq, cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
freq,
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
block.sync();

Expand Down Expand Up @@ -1120,7 +1114,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
RotaryMode ROTARY_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeIn* q, IdType* q_rope_position,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
float rope_scale, float rope_theta, cudaStream_t stream) {
const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
Expand Down Expand Up @@ -1153,6 +1148,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
(void*)&q_rope_position,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
Expand All @@ -1171,6 +1167,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
void* args[] = {(void*)&q,
(void*)&q_rope_position,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
Expand Down Expand Up @@ -1212,7 +1209,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCache(
DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeIn* q, IdType* q_rope_position,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
Expand All @@ -1228,13 +1226,12 @@ cudaError_t BatchDecodeWithPagedKVCache(

DISPATCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, ROTARY_MODE, DTypeIn, DTypeOut,
IdType>(
q, paged_kv, kv_partition_info, o, tmp, lse, rope_scale, rope_theta, stream);
})})});
{DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
return BatchDecodeWithPagedKVCacheDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn,
DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o,
tmp, lse, rope_scale, rope_theta, stream);
})})});

return cudaSuccess;
}
Expand Down
20 changes: 15 additions & 5 deletions include/flashinfer/page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ struct paged_kv_t {
IdType* indptr;
// [batch_size] The offset of the last page for each request in the batch
IdType* last_page_len;
// [batch_size] The start position of each request in the batch.
IdType* rope_pos_offset;

/*!
* \brief Construct an empty paged key-value cache
Expand All @@ -101,7 +103,8 @@ struct paged_kv_t {
indices(nullptr),
ptrs(nullptr),
indptr(nullptr),
last_page_len(nullptr) {}
last_page_len(nullptr),
rope_pos_offset(nullptr) {}

/*!
* \brief Construct a paged key-value cache
Expand All @@ -113,20 +116,23 @@ struct paged_kv_t {
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType* data, IdType* indices, IdType* indptr,
IdType* last_page_len)
IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
data(data),
indices(indices),
indptr(indptr),
last_page_len(last_page_len) {}
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {}

/*!
* \brief Construct a paged key-value cache
Expand All @@ -137,18 +143,22 @@ struct paged_kv_t {
* \param ptrs The array of pointers to each active page
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType** ptrs, IdType* indptr,
IdType* last_page_len)
IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
indptr(indptr) {}
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {}

/*!
* \brief Compute the offset of k element in the allocated buffer.
Expand Down
Loading