From d89fe51633777bc9966e65d8c54c9c9109ae355f Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 12 Jan 2024 23:39:12 -0500
Subject: [PATCH] Support RoPE position info in batch prefill/decode kernels

This PR adds q/k position information to the batch prefill/decode kernels.
More specifically, the kernels now accept two additional arrays:

* `q_rope_position` with shape `(total_q_len,)`, denoting the in-sequence
  position of each token in the input q.
* `k_rope_pos_offset` with shape `(num_sequence,)`, denoting the start
  position of each sequence in k.

These two arrays help compute RoPE on the fly in multi-level cases.

Tests `test_batch_prefill` and `test_batch_decode` pass. Performance has not
been validated yet; per discussion with Zihao, this change is unlikely to
incur a significant perf regression.

---
 include/flashinfer/decode.cuh | 49 +++++-----
 include/flashinfer/page.cuh | 20 +++-
 include/flashinfer/prefill.cuh | 171 +++++++++++++++++++++++++--------
 include/flashinfer/wrapper.cuh | 27 +++---
 python/csrc/batch_decode.cu | 2 +-
 python/csrc/batch_prefill.cu | 4 +-
 src/bench_batch_decode.cu | 17 ++--
 src/test_batch_decode.cu | 16 +--
 src/test_batch_prefill.cu | 6 +-
 src/tvm_wrapper.cu | 97 ++++++++++++++-----
 10 files changed, 282 insertions(+), 127 deletions(-)

diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh
index ef96c7dd..c4a17b4b 100644
--- a/include/flashinfer/decode.cuh
+++ b/include/flashinfer/decode.cuh
@@ -497,7 +497,8 @@ template __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, paged_kv_t paged_kv, + DTypeIn* __restrict__ q, IdType* __restrict__ q_rope_position, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
@@ -520,6 +521,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( : 0; const uint32_t seq_len = partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len; + const uint32_t mapped_batch_idx = + partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx; extern __shared__ uint8_t smem[]; DTypeIn* k_smem = (DTypeIn*)smem;
@@ -541,23 +544,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } // apply rotary embedding to q matrix - if constexpr (partition_kv) { - q_vec = vec_apply_llama_rope( - q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim, - freq, seq_len - 1); - } else { - q_vec = vec_apply_llama_rope( - q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, seq_len - 1); - } + q_vec = vec_apply_llama_rope( + q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, + q_rope_position == nullptr ?
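                // If q_rope_position is not given, fall back to the original behavior:
                // the decode query is the last token of its sequence, i.e. position seq_len - 1.
                // Otherwise, use the caller-provided in-sequence position of this request's
                // query so that RoPE can be computed on the fly over a multi-level KV cache.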
(seq_len - 1) : q_rope_position[mapped_batch_idx]); } else { // do not apply rotary embedding to q matrix - if constexpr (partition_kv) { - q_vec.cast_load( - q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim + - tx * vec_size); - } else { - q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); - } + q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); } block.sync(); @@ -627,7 +619,9 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( block.sync(); compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, - freq, cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz, + freq, + (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + + cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz, iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st); block.sync(); @@ -1120,7 +1114,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( template cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, paged_kv_t paged_kv, + DTypeIn* q, IdType* q_rope_position, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, float rope_scale, float rope_theta, cudaStream_t stream) { const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); @@ -1153,6 +1148,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, + (void*)&q_rope_position, (void*)&paged_kv, (void*)&kv_partition_info, (void*)&o, @@ -1171,6 +1167,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, + (void*)&q_rope_position, (void*)&paged_kv, (void*)&kv_partition_info, (void*)&o, @@ -1212,7 +1209,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( template cudaError_t BatchDecodeWithPagedKVCache( - DTypeIn* q, paged_kv_t paged_kv, + DTypeIn* q, IdType* q_rope_position, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -1228,13 +1226,12 @@ cudaError_t BatchDecodeWithPagedKVCache( DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_HEAD_DIM( - head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { - return BatchDecodeWithPagedKVCacheDispatched( - q, paged_kv, kv_partition_info, o, tmp, lse, rope_scale, rope_theta, stream); - })})}); + {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + return BatchDecodeWithPagedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn, + DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o, + tmp, lse, rope_scale, rope_theta, stream); + })})}); return cudaSuccess; } diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index baa53e5d..80e0a8fd 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -88,6 +88,8 @@ struct paged_kv_t { IdType* indptr; // [batch_size] The offset of the last page for each request in the batch IdType* last_page_len; + // [batch_size] The start position of each request in 
the batch. + IdType* rope_pos_offset; /*! * \brief Construct an empty paged key-value cache @@ -101,7 +103,8 @@ struct paged_kv_t { indices(nullptr), ptrs(nullptr), indptr(nullptr), - last_page_len(nullptr) {} + last_page_len(nullptr), + rope_pos_offset(nullptr) {} /*! * \brief Construct a paged key-value cache @@ -113,12 +116,14 @@ struct paged_kv_t { * \param indices The page indices array * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch + * \param rope_pos_offset The start position of each request in the batch. * \note This constructor should only be used when page_storage == kIndices */ __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, DType* data, IdType* indices, IdType* indptr, - IdType* last_page_len) + IdType* last_page_len, + IdType* rope_pos_offset = nullptr) : num_heads(num_heads), page_size(page_size), head_dim(head_dim), @@ -126,7 +131,8 @@ struct paged_kv_t { data(data), indices(indices), indptr(indptr), - last_page_len(last_page_len) {} + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) {} /*! * \brief Construct a paged key-value cache @@ -137,18 +143,22 @@ struct paged_kv_t { * \param ptrs The array of pointers to each active page * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch + * \param rope_pos_offset The start position of each request in the batch. * \note This constructor should only be used when page_storage == kIndices */ __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, DType** ptrs, IdType* indptr, - IdType* last_page_len) + IdType* last_page_len, + IdType* rope_pos_offset = nullptr) : num_heads(num_heads), page_size(page_size), head_dim(head_dim), batch_size(batch_size), ptrs(ptrs), - indptr(indptr) {} + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) {} /*! * \brief Compute the offset of k element in the allocated buffer. diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index da7c12e3..069c7eb9 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -98,6 +98,38 @@ __device__ __forceinline__ void frag_apply_llama_rope(T* x_first_half, T* x_seco } } +template +__device__ __forceinline__ void frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, + const float* rope_freq, + uint32_t offset, + const IdType* q_rope_position, + float scale = 1.f) { + float pos[2] = {static_cast(q_rope_position[offset]), + static_cast(q_rope_position[offset + (8 / group_size)])}; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + float cos, sin, tmp; + uint32_t i, j; + if constexpr (frag_layout == FragLayout::kRowMajor) { + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + i = ((reg_id % 4) / 2); + j = (reg_id / 4); + } else { + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 + i = reg_id / 4; + j = (reg_id % 4) / 2; + } + __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin) * scale; + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin) * scale; + } +} + /*! * \brief Produce k/v fragments from global memory to shared memory. * \tparam fill_mode The fill mode of the shared memory. 
@@ -308,6 +340,39 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_in; } +template +__device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + const uint32_t q_idx_base, const IdType* q_rope_position, smem_t* q_smem, + uint32_t* q_smem_offset_r, float (*rope_freq)[4], const float sm_scale) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x; + uint32_t q_frag_local[2][4]; + static_assert(num_frags_y % 4 == 0, "num_frags_y must be a multiple of 4"); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + uint32_t q_idx = q_idx_base + (fx * 16 + tx / 4) / group_size; + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#pragma unroll + for (uint32_t fyi = 0; fyi < num_frags_y / 2; ++fyi) { + q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); + frag_apply_llama_rope_with_pos( + (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], q_idx, + q_rope_position, sm_scale); + q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = + q_smem->advance_offset_by_column<2>(q_smem_offset_r_first_half, fyi); + } + *q_smem_offset_r += 16 * channel_size_128b_in; + } + *q_smem_offset_r -= num_frags_x * 16 * channel_size_128b_in; +} + template __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(smem_t* q_smem, const float sm_scale) { @@ -946,8 +1011,9 @@ template (qo_idx_base, qo_len, kv_len, &qo_smem, - &q_smem_offset_r, rope_freq, sm_scale); + if (!q_rope_position) { + q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, + &q_smem_offset_r, rope_freq, sm_scale); + } else { + q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + qo_indptr[request_idx] + qo_idx_base, q_rope_position, &qo_smem, &q_smem_offset_r, + rope_freq, sm_scale); + } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } @@ -1057,7 +1130,9 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( if constexpr (rotary_mode == RotaryMode::kLlama) { k_smem_inplace_apply_rotary( - iter * 16 * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); + (k_rope_pos_offset == nullptr ? 
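              // If k_rope_pos_offset is not given, the keys of this request start at position 0
              // (the original behavior). Otherwise, every key position is shifted by the
              // request's start offset, keeping RoPE positions consistent when the earlier
              // tokens of the sequence live on another level of the KV cache.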
0 : k_rope_pos_offset[request_idx]) + + iter * 16 * num_frags_z, + &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } @@ -1126,9 +1201,9 @@ template paged_kv, - IdType* __restrict__ qo_indptr, DTypeOut* __restrict__ o, float* __restrict__ tmp, - float* __restrict__ lse, float sm_scale, const float log2_rope_rcp_scale, - const float log2_rope_rcp_theta) { + IdType* __restrict__ qo_indptr, IdType* __restrict__ q_rope_position, DTypeOut* __restrict__ o, + float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, + const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); static_assert(sizeof(DTypeOut) == 2); sm_scale *= math::log2e; @@ -1186,9 +1261,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( block.sync(); if constexpr (rotary_mode == RotaryMode::kLlama) { - q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, - &q_smem_offset_r, rope_freq, sm_scale); + if (q_rope_position == nullptr) { + q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, + &q_smem_offset_r, rope_freq, sm_scale); + } else { + q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + qo_indptr[request_idx] + qo_idx_base, q_rope_position, &qo_smem, &q_smem_offset_r, + rope_freq, sm_scale); + } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); } @@ -1230,7 +1312,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr (rotary_mode == RotaryMode::kLlama) { k_smem_inplace_apply_rotary( - iter * 16 * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); + (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[request_idx]) + + iter * 16 * num_frags_z, + &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } @@ -1580,9 +1664,10 @@ template cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, - const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float rope_scale, - const float rope_theta, cudaStream_t stream = nullptr) { + DTypeIn* v, IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, + float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, + const uint32_t num_kv_heads, const float rope_scale, const float rope_theta, + cudaStream_t stream = nullptr) { const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); @@ -1627,6 +1712,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&k, (void*)&v, (void*)&kv_indptr, + (void*)&q_rope_position, + (void*)&k_rope_pos_offset, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1641,12 +1728,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithRaggedKVCache( - DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, DTypeOut* o, - float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, - const uint32_t num_kv_heads, const uint32_t head_dim, bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, - bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f, - const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, + IdType* 
q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, + const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, + const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, + const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; uint32_t num_frags_x, num_qo_tiles; @@ -1683,9 +1770,9 @@ cudaError_t BatchPrefillWithRaggedKVCache( return BatchPrefillWithRaggedKVCacheDispatched< NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices_d, tile_indices_d, qo_indptr, k, v, kv_indptr, o, - tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale, - rope_theta, stream); + q, request_indices_d, tile_indices_d, qo_indptr, k, v, kv_indptr, + q_rope_position, k_rope_pos_offset, o, tmp, lse, batch_size, + num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream); })})})})})})}); FLASHINFER_CUDA_CALL(cudaFreeAsync(request_indices_d, stream)); @@ -1720,9 +1807,9 @@ template cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { + IdType* q_rope_position, paged_kv_t paged_kv, + DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale = 1.f, + float rope_theta = 1e4, cudaStream_t stream = nullptr) { constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD; const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; @@ -1748,8 +1835,9 @@ cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, o, tmp, lse, batch_size, - num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream); + q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, q_rope_position, + paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale, + rope_theta, stream); FLASHINFER_CUDA_CALL(cudaFreeAsync(keys, stream)); FLASHINFER_CUDA_CALL(cudaFreeAsync(values, stream)); @@ -1764,8 +1852,9 @@ template cudaError_t BatchPrefillWithPagedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, - paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_tiles, float rope_scale, float rope_theta, cudaStream_t stream) { + IdType* q_rope_position, paged_kv_t paged_kv, + DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale, float rope_theta, + cudaStream_t stream) { const float sm_scale = 1.f / std::sqrt(float(paged_kv.head_dim)); const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); @@ -1811,6 +1900,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&q, (void*)&paged_kv, (void*)&qo_indptr, + (void*)&q_rope_position, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1825,8 +1915,9 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCache( - DTypeIn* q, IdType* qo_indptr, paged_kv_t paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, bool 
causal = true, + DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + paged_kv_t paged_kv, DTypeOut* o, float* tmp, + float* lse, uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t num_kv_heads = paged_kv.num_heads; @@ -1872,16 +1963,18 @@ cudaError_t BatchPrefillWithPagedKVCache( return BatchPrefillWithPagedKVCacheFallbackDispatched< page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, - DTypeOut, IdType>( - q, request_indices_d, tile_indices_d, qo_indptr, paged_kv, o, - tmp, lse, num_qo_tiles, rope_scale, rope_theta, stream); + DTypeOut, IdType>(q, request_indices_d, tile_indices_d, + qo_indptr, q_rope_position, paged_kv, o, + tmp, lse, num_qo_tiles, rope_scale, + rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices_d, tile_indices_d, qo_indptr, paged_kv, o, - tmp, lse, num_qo_tiles, rope_scale, rope_theta, stream); + q, request_indices_d, tile_indices_d, qo_indptr, + q_rope_position, paged_kv, o, tmp, lse, num_qo_tiles, + rope_scale, rope_theta, stream); } }) diff --git a/include/flashinfer/wrapper.cuh b/include/flashinfer/wrapper.cuh index e8c7a20f..64cd0956 100644 --- a/include/flashinfer/wrapper.cuh +++ b/include/flashinfer/wrapper.cuh @@ -46,7 +46,7 @@ namespace flashinfer { template cudaError_t BatchDecodeWithPagedKVCacheWrapper( - BatchDecodeHandler* handler, DTypeIn* q, + BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -72,15 +72,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( throw std::runtime_error(err_msg.str()); } return BatchDecodeWithPagedKVCache( - q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, - rope_theta, stream); + q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, + rope_scale, rope_theta, stream); } template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; @@ -106,14 +106,14 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( return BatchPrefillWithPagedKVCacheFallbackDispatched< page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles, - rope_scale, rope_theta, stream); + q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, + num_qo_tiles, rope_scale, rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles, - rope_scale, 
rope_theta, stream); + q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, + num_qo_tiles, rope_scale, rope_theta, stream); } })}); return cudaSuccess; @@ -153,9 +153,9 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size, - const uint32_t num_kv_heads, const float rope_scale, const float rope_theta, - cudaStream_t stream) { + IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale, + const float rope_theta, cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -177,8 +177,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( return BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, o, tmp, lse, batch_size, - num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream); + q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position, + k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale, + rope_theta, stream); }); return cudaSuccess; } diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index cbdcceb1..6ed077af 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -160,7 +160,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( static_cast(paged_kv_last_page_len.data_ptr())); cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &handler_, static_cast(q.data_ptr()), paged_kv, + &handler_, static_cast(q.data_ptr()), /*q_rope_position=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), num_qo_heads, RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/nullptr); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 47964a52..06a88eee 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -110,8 +110,8 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( &handler_, static_cast(q.data_ptr()), - static_cast(qo_indptr.data_ptr()), paged_kv, - static_cast(o.data_ptr()), + static_cast(qo_indptr.data_ptr()), + /*q_rope_position=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, rope_scale, rope_theta, /*stream=*/nullptr); diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 79ff571e..dcd1ee50 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -79,8 +79,8 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { head_dim, page_size, rotary_mode); state.exec([&](nvbench::launch&) { cudaError_t status = - BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), paged_kv, + BatchDecodeWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q.data()), /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, rotary_mode); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); @@ -88,10 +88,11 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { }); } else { state.exec([&](nvbench::launch&) { - cudaError_t status = BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q.data()), paged_kv, kv_partition_info_t(), - thrust::raw_pointer_cast(o.data()), nullptr, - /*lse=*/nullptr, num_qo_heads, rotary_mode); + cudaError_t status = + BatchDecodeWithPagedKVCache( + thrust::raw_pointer_cast(q.data()), /*q_rope_position=*/nullptr, paged_kv, + kv_partition_info_t(), thrust::raw_pointer_cast(o.data()), nullptr, + /*lse=*/nullptr, num_qo_heads, rotary_mode); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } @@ -152,9 +153,9 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { qo_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - paged_kv, thrust::raw_pointer_cast(o.data()), + /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, /*causal=*/false, rotary_mode); }); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 84c200ff..917d123c 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -107,16 +107,16 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si if (!cooperative) { // use non-cooperative kernel cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q_device.data()), paged_kv, kv_partition_info_t(), - thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, /*lse=*/nullptr, - num_qo_heads, rotary_mode); + flashinfer::BatchDecodeWithPagedKVCache( + thrust::raw_pointer_cast(q_device.data()), /*q_rope_position=*/nullptr, paged_kv, + kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), + /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, rotary_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } else { - cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), paged_kv, - thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, rotary_mode); + cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper( + &handler, thrust::raw_pointer_cast(q_device.data()), /*q_rope_position=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, rotary_mode); EXPECT_EQ(status, 
cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } // compare result diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index b3c247b7..69b8154a 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -101,7 +101,7 @@ void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), - thrust::raw_pointer_cast(q_indptr_device.data()), paged_kv, + thrust::raw_pointer_cast(q_indptr_device.data()), /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); @@ -216,7 +216,7 @@ void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), - paged_kv, thrust::raw_pointer_cast(o_device.data()), + /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); @@ -301,7 +301,7 @@ void _TestBatchPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t n auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), - paged_kv, thrust::raw_pointer_cast(o_device.data()), + /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 9e5d0f2b..86798240 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -184,7 +184,7 @@ thread_local BatchPrefillHandler batch_prefill_ragged_kv_handler; template cudaError_t _BatchPrefillWithPagedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, @@ -202,8 +202,8 @@ cudaError_t _BatchPrefillWithPagedKVCacheWrapper( return BatchPrefillWithPagedKVCacheWrapperDispatched< page_storage, kv_layout, GROUP_SIZE, /*head_dim=*/128, ROTARY_MODE, /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, paged_kv, o, lse, rope_scale, rope_theta, - stream); + handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale, + rope_theta, stream); })})}); return cudaSuccess; } @@ -214,6 +214,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q DLTensor* page_table_indptr, // DLTensor* page_table_values, // DLTensor* last_page_len, // + DLTensor* k_rope_pos_offset, // + DLTensor* q_rope_position, // DLTensor* output, // DLTensor* lse, // int64_t causal = 1, // @@ -229,6 +231,10 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q << 
"The device of page_table_values matrix must be CUDA."; CHECK_EQ(last_page_len->device.device_type, kDLCUDA) << "The device of last_page_len matrix must be CUDA."; + CHECK_EQ(q_rope_position->device.device_type, kDLCUDA) + << "The device of q_rope_position matrix must be CUDA."; + CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) + << "The device of k_rope_pos_offset matrix must be CUDA."; CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr matrix must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; @@ -238,18 +244,23 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(page_table_indptr->device.device_id, dev_id); CHECK_EQ(page_table_values->device.device_id, dev_id); CHECK_EQ(last_page_len->device.device_id, dev_id); + CHECK_EQ(q_rope_position->device.device_id, dev_id); + CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK_EQ(qo_indptr->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1); + last_page_len->dtype.lanes == 1 && q_rope_position->dtype.lanes == 1 && + k_rope_pos_offset->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1); CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && page_table_indptr->dtype.bits == last_page_len->dtype.bits && page_table_indptr->dtype.bits == qo_indptr->dtype.bits && page_table_indptr->dtype.code == page_table_values->dtype.code && page_table_indptr->dtype.code == last_page_len->dtype.code && + page_table_indptr->dtype.code == q_rope_position->dtype.code && + page_table_indptr->dtype.code == k_rope_pos_offset->dtype.code && page_table_indptr->dtype.code == qo_indptr->dtype.code); CHECK_EQ(pages->ndim, 5); @@ -274,6 +285,11 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(q_data->shape[2], nfeat); CHECK_EQ(output->shape[1], nhead_qo); CHECK_EQ(output->shape[2], nfeat); + CHECK_EQ(q_rope_position->ndim, 1); + CHECK_EQ(q_rope_position->shape[0], q_data->shape[0]); + + CHECK_EQ(k_rope_pos_offset->ndim, 1); + CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; @@ -286,13 +302,15 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q nhead_kv, page_size, nfeat, num_total_seqs, static_cast(pages->data), static_cast(page_table_values->data), static_cast(page_table_indptr->data), - static_cast(last_page_len->data)); + static_cast(last_page_len->data), + static_cast(k_rope_pos_offset->data)); cudaError_t status = _BatchPrefillWithPagedKVCacheWrapper( &batch_prefill_paged_kv_handlers[handler_id], static_cast(q_data->data), static_cast(qo_indptr->data), - cache, static_cast(output->data), + static_cast(q_rope_position->data), cache, + static_cast(output->data), /*lse=*/static_cast(lse->data), nhead_qo, /*causal=*/causal, RotaryMode(rotary_mode), /*allow_fp16_qk_reduction=*/false, rope_scale, rope_theta, 0); @@ -331,6 +349,8 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ DLTensor* page_table_indptr, // DLTensor* page_table_values, // DLTensor* last_page_len, 
// + DLTensor* k_rope_pos_offset, // + DLTensor* q_rope_position, // DLTensor* output, // DLTensor* lse, // int64_t rotary_mode = 0, // @@ -345,6 +365,10 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ << "The device of page_table_values matrix must be CUDA."; CHECK_EQ(last_page_len->device.device_type, kDLCUDA) << "The device of last_page_len matrix must be CUDA."; + CHECK_EQ(q_rope_position->device.device_type, kDLCUDA) + << "The device of q_rope_position matrix must be CUDA."; + CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) + << "The device of k_rope_pos_offset matrix must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; int32_t dev_id = q_data->device.device_id; @@ -352,16 +376,21 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ CHECK_EQ(page_table_indptr->device.device_id, dev_id); CHECK_EQ(page_table_values->device.device_id, dev_id); CHECK_EQ(last_page_len->device.device_id, dev_id); + CHECK_EQ(q_rope_position->device.device_id, dev_id); + CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1); + last_page_len->dtype.lanes == 1 && q_rope_position->dtype.lanes == 1 && + k_rope_pos_offset->dtype.lanes == 1); CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && page_table_indptr->dtype.bits == last_page_len->dtype.bits && page_table_indptr->dtype.code == page_table_values->dtype.code && - page_table_indptr->dtype.code == last_page_len->dtype.code); + page_table_indptr->dtype.code == last_page_len->dtype.code && + page_table_indptr->dtype.code == q_rope_position->dtype.code && + page_table_indptr->dtype.code == k_rope_pos_offset->dtype.code); CHECK_EQ(pages->ndim, 5); CHECK_EQ(pages->shape[1], 2); @@ -384,6 +413,11 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ int64_t nhead_qo = q_data->shape[1]; CHECK_EQ(output->shape[1], nhead_qo); CHECK_EQ(output->shape[2], nfeat); + CHECK_EQ(q_rope_position->ndim, 1); + CHECK_EQ(q_rope_position->shape[0], num_total_seqs); + + CHECK_EQ(k_rope_pos_offset->ndim, 1); + CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; @@ -396,10 +430,12 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ nhead_kv, page_size, nfeat, num_total_seqs, static_cast(pages->data), static_cast(page_table_values->data), static_cast(page_table_indptr->data), - static_cast(last_page_len->data)); + static_cast(last_page_len->data), + static_cast(k_rope_pos_offset->data)); cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &batch_decode_handlers[handler_id], static_cast(q_data->data), cache, + &batch_decode_handlers[handler_id], static_cast(q_data->data), + static_cast(q_rope_position->data), cache, static_cast(output->data), /*lse=*/static_cast(lse->data), nhead_qo, RotaryMode(rotary_mode), rope_scale, rope_theta, 0); @@ -449,9 +485,9 @@ void _FlashInferAttentionDecodeWithPagedKVCacheEndForward(int64_t handler_id) { template cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( BatchPrefillHandler* handler, 
DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size, - const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, - bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + IdType* kv_indptr, IdType* q_rope_position_map, IdType* k_rope_pos_offset, DTypeOut* o, + float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, + const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { CHECK(lse != nullptr) << "The lse buffer must be provided"; @@ -464,24 +500,29 @@ cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( return BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, ROTARY_MODE, /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size, - num_kv_heads, rope_scale, rope_theta, stream); + handler, q, qo_indptr, k, v, kv_indptr, q_rope_position_map, + k_rope_pos_offset, o, lse, batch_size, num_kv_heads, rope_scale, + rope_theta, stream); })})}); return cudaSuccess; } -void _FlashInferAttentionPrefillWithRaggedKVCache(DLTensor* q_data, DLTensor* qo_indptr, - DLTensor* k_data, DLTensor* v_data, - DLTensor* kv_indptr, DLTensor* output, - DLTensor* lse, int64_t causal = 1, - int64_t rotary_mode = 0, double rope_scale = 1.0f, - double rope_theta = 1e4) { +void _FlashInferAttentionPrefillWithRaggedKVCache( + DLTensor* q_data, DLTensor* qo_indptr, DLTensor* k_data, DLTensor* v_data, DLTensor* kv_indptr, + DLTensor* q_rope_position_map, DLTensor* k_rope_pos_offset, DLTensor* output, DLTensor* lse, + int64_t causal = 1, int64_t rotary_mode = 0, double rope_scale = 1.0f, + double rope_theta = 1e4) { CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr must be CUDA."; CHECK_EQ(k_data->device.device_type, kDLCUDA) << "The device of k_data must be CUDA."; CHECK_EQ(v_data->device.device_type, kDLCUDA) << "The device of v_data must be CUDA."; CHECK_EQ(kv_indptr->device.device_type, kDLCUDA) << "The device of kv_indptr must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; + CHECK_EQ(lse->device.device_type, kDLCUDA) << "The lse of output must be CUDA."; + CHECK_EQ(q_rope_position_map->device.device_type, kDLCUDA) + << "The device of q_rope_position_map must be CUDA."; + CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) + << "The device of k_rope_pos_offset must be CUDA."; int dev_id = q_data->device.device_id; CHECK_EQ(qo_indptr->device.device_id, dev_id); @@ -490,15 +531,20 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(DLTensor* q_data, DLTensor* qo CHECK_EQ(kv_indptr->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); CHECK_EQ(lse->device.device_id, dev_id); + CHECK_EQ(q_rope_position_map->device.device_id, dev_id); + CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK(q_data->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1 && k_data->dtype.lanes == 1 && v_data->dtype.lanes == 1 && kv_indptr->dtype.lanes == 1 && output->dtype.lanes == 1 && - lse->dtype.lanes == 1); + lse->dtype.lanes == 1 && q_rope_position_map->dtype.lanes == 1 && + 
k_rope_pos_offset->dtype.lanes == 1); CHECK(q_data->dtype.bits == k_data->dtype.bits && q_data->dtype.code == v_data->dtype.code); CHECK(qo_indptr->dtype.bits == kv_indptr->dtype.bits); CHECK(lse->dtype.bits == 32); CHECK(q_data->dtype.code == k_data->dtype.code && q_data->dtype.code == v_data->dtype.code); CHECK(qo_indptr->dtype.code == kv_indptr->dtype.code); + CHECK(q_rope_position_map->dtype.code == kv_indptr->dtype.code); + CHECK(k_rope_pos_offset->dtype.code == kv_indptr->dtype.code); CHECK(lse->dtype.code == kDLFloat); CHECK_EQ(q_data->ndim, 3); // qo_nnz, nhead_qo, nfeat @@ -524,6 +570,11 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(DLTensor* q_data, DLTensor* qo int64_t batch_size = qo_indptr->shape[0] - 1; CHECK_EQ(kv_indptr->shape[0], batch_size + 1); + CHECK_EQ(q_rope_position_map->ndim, 1); + CHECK_EQ(q_rope_position_map->shape[0], q_data->shape[0]); + CHECK_EQ(k_rope_pos_offset->ndim, 1); + CHECK_EQ(k_rope_pos_offset->shape[0], batch_size); + DISPATCH_TVM_CUDA_DTYPE( q_data->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE( @@ -533,6 +584,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(DLTensor* q_data, DLTensor* qo &batch_prefill_ragged_kv_handler, static_cast(q_data->data), static_cast(qo_indptr->data), static_cast(k_data->data), static_cast(v_data->data), static_cast(kv_indptr->data), + static_cast(q_rope_position_map->data), + static_cast(k_rope_pos_offset->data), static_cast(output->data), /*lse=*/static_cast(lse->data), batch_size, nhead_qo, nhead_kv, nfeat, /*causal=*/bool(causal), QKVLayout::kNHD, RotaryMode(rotary_mode),