From c5e3437b74288938a78647be737cf4301b43ef24 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Wed, 17 Jan 2024 09:42:31 -0500 Subject: [PATCH 1/7] upd --- include/flashinfer/decode.cuh | 248 +++++++++++++-------------------- include/flashinfer/handler.cuh | 77 +++++----- include/flashinfer/page.cuh | 116 +++++---------- include/flashinfer/prefill.cuh | 76 +++++----- src/bench_batch_decode.cu | 3 +- src/bench_single_decode.cu | 2 +- src/test_batch_decode.cu | 2 +- src/test_single_decode.cu | 4 +- src/tvm_wrapper.cu | 4 +- 9 files changed, 222 insertions(+), 310 deletions(-) diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh index 2179ec9972..f4a0cfd059 100644 --- a/include/flashinfer/decode.cuh +++ b/include/flashinfer/decode.cuh @@ -27,6 +27,7 @@ #include #include +#include "cascade.cuh" #include "cp_async.cuh" #include "layout.cuh" #include "math.cuh" @@ -172,7 +173,7 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f /*! * \brief FlashAttention decoding cuda kernel with kv-cache for a single request * \tparam layout The layout of k/v matrices (NHD or HND) - * \tparam cooperative Whether to use cooperative kernel or not + * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not * \tparam rotary_mode The rotary mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension @@ -194,12 +195,12 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ -template __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, - float* __restrict__ tmp, + DTypeOut* __restrict__ tmp, tensor_info_t info, float sm_scale, float rope_rcp_scale, float rope_rcp_theta, uint32_t kv_chunk_size) { @@ -329,38 +330,15 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* // sync local state of all warps inside a threadblock sync_state(st_local, reinterpret_cast(smem), smem_md); + st_local.normalize(); - if constexpr (cooperative) { + if constexpr (partition_kv) { // update tmp buffer - st_local.o.store(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim + tx * vec_size); - float* tmp_md = tmp + num_qo_heads * num_kv_chunks * head_dim; - *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2] = - make_float2(st_local.m, st_local.d); - grid.sync(); - - // sync global states - if (kv_chunk_idx == 0) { - state_t st_global; -#pragma unroll 2 - for (uint32_t iter = 0; iter < ceil_div(num_kv_chunks, bdz); ++iter) { - uint32_t kv_chunk_idx = iter * bdz + tz; - if (kv_chunk_idx < num_kv_chunks) { - float2 md = *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2]; - st_local.m = md.x; - st_local.d = md.y; - st_local.o.load(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim + + st_local.o.cast_store(tmp + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); - st_global.merge(st_local); - } - } - block.sync(); - // sync local state of all warps inside a threadblock - sync_state(st_global, reinterpret_cast(smem), smem_md); - st_global.normalize(); - st_global.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); - } + float* tmp_lse = (float*)(tmp + num_kv_chunks * num_qo_heads * head_dim); + tmp_lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); } else { - st_local.normalize(); st_local.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); } } @@ -494,7 +472,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeIn* __restrict__ q, DTyp /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests - * \tparam cooperative Whether to use cooperative kernel or not + * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not * \tparam rotary_mode The rotary mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension @@ -516,13 +494,14 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeIn* __restrict__ q, DTyp * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ -template __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeIn* __restrict__ q, paged_kv_t paged_kv, - DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, - float rope_rcp_scale, float rope_rcp_theta) { + kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, + DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, + float rope_rcp_theta) { auto block = cg::this_thread_block(); sm_scale *= math::log2e; @@ -531,7 +510,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( const uint32_t kv_head_idx = blockIdx.y; const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; const uint32_t num_qo_heads = gridDim.y * bdy; - const uint32_t cur_chunk_start = cooperative ? paged_kv.chunk_start()[batch_idx] : 0U; + const uint32_t cur_chunk_start = partition_kv ? kv_partition_info.chunk_start_pos[batch_idx] : 0U; const uint32_t cur_page_indptr_begin = paged_kv.indptr[batch_idx], cur_page_indptr_end = paged_kv.indptr[batch_idx + 1]; const uint32_t cur_last_page_len = paged_kv.last_page_len[batch_idx]; @@ -540,7 +519,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( ? (cur_page_indptr_end - cur_page_indptr_begin - 1) * paged_kv.page_size + cur_last_page_len : 0; - const uint32_t seq_len = cooperative ? paged_kv.seq_lens_before_split()[batch_idx] : kv_chunk_len; + const uint32_t seq_len = + partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len; extern __shared__ uint8_t smem[]; DTypeIn* k_smem = (DTypeIn*)smem; @@ -562,19 +542,19 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } // apply rotary embedding to q matrix - if constexpr (cooperative) { + if constexpr (partition_kv) { q_vec = vec_apply_llama_rope( - q + (paged_kv.batch_idx_map()[batch_idx] * num_qo_heads + qo_head_idx) * head_dim, freq, - seq_len - 1); + q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim, + freq, seq_len - 1); } else { q_vec = vec_apply_llama_rope( q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, seq_len - 1); } } else { // do not apply rotary embedding to q matrix - if constexpr (cooperative) { + if constexpr (partition_kv) { q_vec.cast_load( - q + (paged_kv.batch_idx_map()[batch_idx] * num_qo_heads + qo_head_idx) * head_dim + + q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); } else { q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); @@ -697,48 +677,13 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // sync local state of all warps inside a threadblock sync_state(st, reinterpret_cast(smem), smem_md); + st.normalize(); - if constexpr (cooperative) { - auto grid = cg::this_grid(); - // update tmp buffer - st.o.store(tmp + (qo_head_idx * paged_kv.batch_size + batch_idx) * head_dim + tx * vec_size); - float* tmp_md = tmp + num_qo_heads * paged_kv.batch_size * head_dim; - *(float2*)&tmp_md[(qo_head_idx * paged_kv.batch_size + batch_idx) * 2] = - make_float2(st.m, st.d); - grid.sync(); - - // sync global states - const uint32_t cooperative_indptr_begin = paged_kv.cooperative_indptr()[batch_idx], - cooperative_indptr_end = paged_kv.cooperative_indptr()[batch_idx + 1]; - if (cooperative_indptr_begin < cooperative_indptr_end) { - state_t st_global; - const uint32_t num_pages = cooperative_indptr_end - cooperative_indptr_begin; -#pragma unroll 2 - for (uint32_t iter = 0; iter < ceil_div(num_pages, bdz); ++iter) { - uint32_t kv_chunk_idx = cooperative_indptr_begin + iter * bdz + tz; - if (kv_chunk_idx < cooperative_indptr_end) { - float2 md = *(float2*)&tmp_md[(qo_head_idx * paged_kv.batch_size + kv_chunk_idx) * 2]; - st.m = md.x; - st.d = md.y; - st.o.load(tmp + (qo_head_idx * paged_kv.batch_size + kv_chunk_idx) * head_dim + - tx * vec_size); - st_global.merge(st); - } - } - block.sync(); - // sync local state of all warps inside a threadblock - sync_state(st_global, reinterpret_cast(smem), smem_md); - st_global.normalize(); - st_global.o.cast_store( - o + (paged_kv.batch_idx_map()[batch_idx] * num_qo_heads + qo_head_idx) * head_dim + - tx * vec_size); - // write lse - if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx] = st_global.get_lse(); - } - } + if constexpr (partition_kv) { + st.o.cast_store(tmp + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + float* tmp_lse = (float*)(tmp + paged_kv.batch_size * num_qo_heads * head_dim); + tmp_lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse(); } else { - st.normalize(); st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); // write lse if (lse != nullptr) { @@ -766,11 +711,11 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo /*! * \brief Esitmate the temporary buffer size and the maximum grid size for the - * cooperative SingleDecodeWithKVCache kernel + * partiton-kv SingleDecodeWithKVCache kernel * \tparam DTypeIn A template type indicates the input data type * \tparam DTypeOut A template type indicates the output data type - * \param tmp_size The estimated temporary buffer size, return 0 if not use cooperative kernel - * \param max_grid_size The maximum grid size that can be used in a cooperative kernel + * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel + * \param max_grid_size The maximum grid size that can be used in a partition-kv kernel * \param num_qo_heads A integer indicates the number of heads of query and output * \param num_kv_heads A integer indicates the number of heads of key and value * \param seq_len A integer indicates the sequence length @@ -811,7 +756,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& 2U * bdy * bdz * sizeof(float); auto kernel = - SingleDecodeWithKVCacheKernel; int num_blocks_per_sm = 0; @@ -826,7 +771,8 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); uint32_t num_kv_chunks = ceil_div(seq_len, kv_chunk_size); - tmp_size = num_qo_heads * num_kv_chunks * (head_dim + 2); + tmp_size = num_qo_heads * num_kv_chunks * + (head_dim * sizeof(DTypeOut) + 2 * sizeof(float)); })})})}); } return cudaSuccess; @@ -855,7 +801,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, +cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, QKVLayout layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, @@ -890,9 +836,9 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut head_dim * sizeof(DTypeIn) + 2U * bdy * bdz * sizeof(float); if (seq_len <= 256 || tmp == nullptr) { - // no need to use cooperative kernel + // no need to use partition-kv kernel auto kernel = - SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( @@ -913,9 +859,9 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { - // use cooperative kernel + // use partition-kv kernel auto kernel = - SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( @@ -932,7 +878,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); - dim3 nblks = dim3(ceil_div(seq_len, kv_chunk_size), num_kv_heads); + uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); + dim3 nblks = dim3(num_chunks, num_kv_heads); if (nblks.x == 0 || nblks.y == 0) { std::cerr << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")" << std::endl; @@ -949,8 +896,11 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut (void*)&rope_rcp_scale, (void*)&rope_rcp_theta, (void*)&kv_chunk_size}; - FLASHINFER_CUDA_CALL(cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, - args, smem_size, stream)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL( + MergeStates(tmp, (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o, + nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); } })})})}); return cudaSuccess; @@ -971,10 +921,10 @@ template cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( const uint32_t max_num_pages_per_batch, const uint32_t old_batch_size, const uint32_t page_size, IdType* old_indptr, IdType* old_last_page_len, IdType* new_indptr_d, - IdType* new_last_page_len_d, IdType* cooperative_indptr_d, IdType* batch_idx_map_d, + IdType* new_last_page_len_d, IdType* chunk_indptr_d, IdType* batch_idx_map_d, IdType* chunk_start_d, IdType* seq_lens_before_split_d, cudaStream_t stream = nullptr) { - std::vector new_page_indptr_h{0}, new_last_page_len_h, cooperative_indptr_h{0}, - batch_idx_map_h, chunk_start_h, seq_lens_before_split_h; + std::vector new_page_indptr_h{0}, new_last_page_len_h, chunk_indptr_h{0}, batch_idx_map_h, + chunk_start_pos_h, seq_lens_before_split_h; std::vector old_indptr_h(old_batch_size + 1), old_last_page_len_h(old_batch_size); if (is_device_ptr(old_indptr)) { @@ -991,31 +941,26 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( } for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) { - uint32_t cooperative_indptr_delta = + uint32_t num_chunks = ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch); - if (cooperative_indptr_delta == 0) { + chunk_indptr_h.push_back(chunk_indptr_h.back() + num_chunks); + if (num_chunks == 0) { new_page_indptr_h.push_back(old_indptr_h[batch_idx]); new_last_page_len_h.push_back(0); batch_idx_map_h.push_back(batch_idx); - cooperative_indptr_h.push_back(cooperative_indptr_h.back()); - chunk_start_h.push_back(0); + chunk_start_pos_h.push_back(0); seq_lens_before_split_h.push_back(0); } else { uint32_t seq_len_before_split = (old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx] - 1) * page_size + old_last_page_len_h[batch_idx]; - for (uint32_t j = 0; j < cooperative_indptr_delta; ++j) { - bool is_last = (j + 1) == cooperative_indptr_delta; + for (uint32_t j = 0; j < num_chunks; ++j) { + bool is_last = (j + 1) == num_chunks; new_page_indptr_h.push_back(min(old_indptr_h[batch_idx] + (j + 1) * max_num_pages_per_batch, old_indptr_h[batch_idx + 1])); new_last_page_len_h.push_back(is_last ? old_last_page_len_h[batch_idx] : page_size); batch_idx_map_h.push_back(batch_idx); - if (j == 0) { - cooperative_indptr_h.push_back(cooperative_indptr_h.back() + cooperative_indptr_delta); - } else { - cooperative_indptr_h.push_back(cooperative_indptr_h.back()); - } - chunk_start_h.push_back(j * max_num_pages_per_batch * page_size); + chunk_start_pos_h.push_back(j * max_num_pages_per_batch * page_size); seq_lens_before_split_h.push_back(seq_len_before_split); } } @@ -1027,14 +972,14 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( FLASHINFER_CUDA_CALL(cudaMemcpyAsync(new_last_page_len_d, new_last_page_len_h.data(), sizeof(IdType) * new_last_page_len_h.size(), cudaMemcpyHostToDevice, stream)); - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(cooperative_indptr_d, cooperative_indptr_h.data(), - sizeof(IdType) * cooperative_indptr_h.size(), + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(chunk_indptr_d, chunk_indptr_h.data(), + sizeof(IdType) * chunk_indptr_h.size(), cudaMemcpyHostToDevice, stream)); FLASHINFER_CUDA_CALL(cudaMemcpyAsync(batch_idx_map_d, batch_idx_map_h.data(), sizeof(IdType) * batch_idx_map_h.size(), cudaMemcpyHostToDevice, stream)); - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(chunk_start_d, chunk_start_h.data(), - sizeof(IdType) * chunk_start_h.size(), + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(chunk_start_d, chunk_start_pos_h.data(), + sizeof(IdType) * chunk_start_pos_h.size(), cudaMemcpyHostToDevice, stream)); FLASHINFER_CUDA_CALL(cudaMemcpyAsync(seq_lens_before_split_d, seq_lens_before_split_h.data(), sizeof(IdType) * seq_lens_before_split_h.size(), @@ -1085,13 +1030,13 @@ std::pair SplitPagedKVCacheBinarySearchMinNumPagePerBatch( /*! * \brief Estimate the temporary buffer size and the maximum grid size for the - * cooperative BatchDecodeWithPagedKVCache kernel + * partition-kv BatchDecodeWithPagedKVCache kernel * \tparam page_storage Whether to store indices or pointers of each active page * \tparam DTypeIn A template type indicates the input data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type - * \param tmp_size The estimated temporary buffer size, return 0 if not use cooperative kernel - * \param max_grid_size The maximum grid size that can be used in a cooperative kernel + * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel + * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel * \param max_num_pages_per_batch The maximum number of pages per batch * \param new_batch_size The new batch size after the split * \param paged_kv The paged kv cache data structure @@ -1123,10 +1068,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); - auto cooperative_kernel = - BatchDecodeWithPagedKVCacheKernel; + auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< + /*partition_kv=*/true, ROTARY_MODE, num_stages_smem, tile_size_per_bdx, vec_size, + bdx, bdy, bdz, page_storage, DTypeIn, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -1134,10 +1078,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, cooperative_kernel, num_threads, smem_size)); + &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size)); max_grid_size = num_blocks_per_sm * num_sm; if (batch_size * num_kv_heads >= max_grid_size) { - // do not use cooperative kernel + // do not use partition-kv kernel tmp_size = 0; new_batch_size = batch_size; } else { @@ -1157,10 +1101,11 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( SplitPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages, 512 / page_size); if (new_batch_size == batch_size) { - // do not use cooperative kernel for short sequence + // do not use partition-kv kernel for short sequence tmp_size = 0; } else { - tmp_size = num_qo_heads * new_batch_size * (head_dim + 2); + tmp_size = num_qo_heads * new_batch_size * + (head_dim * sizeof(DTypeOut) + 2 * sizeof(float)); } } })})}); @@ -1170,13 +1115,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( template cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, paged_kv_t paged_kv, DTypeOut* o, float* tmp, - float* lse, float rope_scale, float rope_theta, cudaStream_t stream) { + DTypeIn* q, paged_kv_t paged_kv, + kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, + float rope_scale, float rope_theta, cudaStream_t stream) { const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t batch_size = paged_kv.batch_size; + const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; @@ -1191,17 +1138,18 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); if (tmp == nullptr) { - // do not use cooperative kernel + // do not use partition-kv kernel dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); auto kernel = - BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, (void*)&paged_kv, + (void*)&kv_partition_info, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1210,20 +1158,16 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( (void*)&rope_rcp_theta}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { - // use cooperative kernel - if (paged_kv.cooperative_aux_info == nullptr) { - std::cerr << "cooperative_aux_info is not defined for cooperative BatchDecode kernel." - << std::endl; - abort(); - } - auto cooperative_kernel = - BatchDecodeWithPagedKVCacheKernel; + // use partition-kv kernel + auto partition_kv_kernel = + BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - cooperative_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, (void*)&paged_kv, + (void*)&kv_partition_info, (void*)&o, (void*)&tmp, (void*)&lse, @@ -1232,8 +1176,11 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( (void*)&rope_rcp_theta}; dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); - FLASHINFER_CUDA_CALL(cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks, nthrs, args, - smem_size, stream)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp, (float*)(tmp + batch_size * num_qo_heads * HEAD_DIM), kv_partition_info.chunk_indptr, + o, lse, kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); } return cudaSuccess; @@ -1260,7 +1207,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( template cudaError_t BatchDecodeWithPagedKVCache(DTypeIn* q, paged_kv_t paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, + kv_partition_info_t kv_partition_info, DTypeOut* o, + DTypeOut* tmp, float* lse, uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -1279,7 +1227,7 @@ cudaError_t BatchDecodeWithPagedKVCache(DTypeIn* q, head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { return BatchDecodeWithPagedKVCacheDispatched( - q, paged_kv, o, tmp, lse, rope_scale, rope_theta, stream); + q, paged_kv, kv_partition_info, o, tmp, lse, rope_scale, rope_theta, stream); })})}); return cudaSuccess; @@ -1329,13 +1277,11 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType } template -cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, float* tmp, - float* lse, uint32_t batch_size, uint32_t padded_kv_len, - uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, QKVLayout layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, - float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { +cudaError_t BatchDecodeWithPaddedKVCache( + DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, DTypeOut* tmp, float* lse, uint32_t batch_size, + uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + QKVLayout layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { if (num_qo_heads % num_kv_heads != 0) { std::cerr << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " << num_kv_heads << std::endl; diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index 77c847ae88..5b56855f10 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -30,7 +30,10 @@ namespace flashinfer { class BatchDecodeHandler { public: - float* GetTempFloatBuffer() const { return float_buffer_; } + template + DType* GetTempFloatBuffer() const { + return (DType*)float_buffer_; + } template IdType* GetNewIndPtr() const { return (IdType*)int_buffer_; @@ -38,32 +41,24 @@ class BatchDecodeHandler { template IdType* GetNewLastPageLen() const { if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + new_batch_size_ + 1; - } else { - return nullptr; - } - } - // cooperative_aux_info starts with cooperative_indptr - template - IdType* GetCooperativeAuxInfo() const { - if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1; + return ((IdType*)int_buffer_) + batch_size_after_partition_ + 1; } else { return nullptr; } } template - IdType* GetCooperativeIndPtr() const { + IdType* GetChunkIndPtr() const { if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1; + return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + 1; } else { return nullptr; } } template - IdType* GetBatchIndexMap() const { + IdType* GetBatchIdxMap() const { if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + 3 * new_batch_size_ + 2; + return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + + batch_size_before_partition_ + 1; } else { return nullptr; } @@ -71,15 +66,17 @@ class BatchDecodeHandler { template IdType* GetChunkStartPos() const { if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + 4 * new_batch_size_ + 2; + return ((IdType*)int_buffer_) + 3 * batch_size_after_partition_ + + batch_size_before_partition_ + 2; } else { return nullptr; } } template - IdType* GetSeqLengthsBeforeSplit() const { + IdType* GetSeqLengthsBeforePartition() const { if (int_buffer_ != nullptr) { - return ((IdType*)int_buffer_) + 5 * new_batch_size_ + 2; + return ((IdType*)int_buffer_) + 4 * batch_size_after_partition_ + + batch_size_before_partition_ + 2; } else { return nullptr; } @@ -89,22 +86,24 @@ class BatchDecodeHandler { cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, RotaryMode rotary_mode) { + batch_size_before_partition_ = batch_size; uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimation; FLASHINFER_CUDA_CALL(work_estimation_func( tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_)); - new_batch_size_ = new_batch_size; + batch_size_after_partition_ = new_batch_size; if (tmp_size > 0) { - FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, sizeof(float) * tmp_size, stream_)); - FLASHINFER_CUDA_CALL( - cudaMallocAsync(&int_buffer_, sizeof(IdType) * (6 * new_batch_size + 2), stream_)); + FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_)); + FLASHINFER_CUDA_CALL(cudaMallocAsync( + &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), + stream_)); FLASHINFER_CUDA_CALL(SplitPagedCacheKVComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, - GetNewIndPtr(), GetNewLastPageLen(), GetCooperativeIndPtr(), - GetBatchIndexMap(), GetChunkStartPos(), - GetSeqLengthsBeforeSplit(), stream_)); + GetNewIndPtr(), GetNewLastPageLen(), GetChunkIndPtr(), + GetBatchIdxMap(), GetChunkStartPos(), + GetSeqLengthsBeforePartition(), stream_)); } forward_started_ = true; return cudaSuccess; @@ -112,7 +111,8 @@ class BatchDecodeHandler { cudaError_t EndForward() { forward_started_ = false; - new_batch_size_ = 0; + batch_size_before_partition_ = 0; + batch_size_after_partition_ = 0; if (float_buffer_ != nullptr) { FLASHINFER_CUDA_CALL(cudaFreeAsync(float_buffer_, stream_)); float_buffer_ = nullptr; @@ -126,14 +126,16 @@ class BatchDecodeHandler { bool IsForwardStarted() const { return forward_started_; } - uint32_t GetNewBatchSize() const { return new_batch_size_; } + uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; } + + uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; } cudaStream_t GetCUDAStream() const { return stream_; } void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } BatchDecodeHandler() - : new_batch_size_(0U), + : batch_size_after_partition_(0U), float_buffer_(nullptr), int_buffer_(nullptr), forward_started_(false), @@ -141,8 +143,9 @@ class BatchDecodeHandler { ~BatchDecodeHandler() { EndForward(); } private: - uint32_t new_batch_size_; - float* float_buffer_; + uint32_t batch_size_before_partition_; + uint32_t batch_size_after_partition_; + void* float_buffer_; void* int_buffer_; bool forward_started_; cudaStream_t stream_; @@ -253,14 +256,19 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { paged_kv_t new_paged_kv = paged_kv; - float* tmp = handler->GetTempFloatBuffer(); + kv_partition_info_t kv_partition_info; + DTypeOut* tmp = handler->GetTempFloatBuffer(); if (handler->IsForwardStarted()) { if (tmp != nullptr) { // create auxiliary information for cooperative kernels - new_paged_kv.batch_size = handler->GetNewBatchSize(); + new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition(); new_paged_kv.indptr = handler->GetNewIndPtr(); new_paged_kv.last_page_len = handler->GetNewLastPageLen(); - new_paged_kv.cooperative_aux_info = handler->GetCooperativeAuxInfo(); + kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition(); + kv_partition_info.chunk_indptr = handler->GetChunkIndPtr(); + kv_partition_info.batch_idx_map = handler->GetBatchIdxMap(); + kv_partition_info.chunk_start_pos = handler->GetChunkStartPos(); + kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition(); } } else { std::cerr << "Please call BatchDecodeHandler's BeginForward() before calling " @@ -269,7 +277,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp abort(); } return BatchDecodeWithPagedKVCache( - q, new_paged_kv, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, rope_theta, stream); + q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, + rope_theta, stream); } template +struct kv_partition_info_t { + uint32_t batch_size_before_partition; + IdType* chunk_indptr; + IdType* batch_idx_map; + IdType* chunk_start_pos; + IdType* seq_lens_before_partition; + + __host__ __device__ __forceinline__ kv_partition_info_t(uint32_t batch_size_before_partition, + IdType* chunk_indptr, + IdType* batch_idx_map, + IdType* chunk_start_pos, + IdType* seq_lens_before_partition) + : batch_size_before_partition(batch_size_before_partition), + chunk_indptr(chunk_indptr), + batch_idx_map(batch_idx_map), + chunk_start_pos(chunk_start_pos), + seq_lens_before_partition(seq_lens_before_partition) {} + + __host__ __device__ __forceinline__ kv_partition_info_t() + : batch_size_before_partition(0), + chunk_indptr(nullptr), + batch_idx_map(nullptr), + chunk_start_pos(nullptr), + seq_lens_before_partition(nullptr) {} +}; + /*! * \brief Paged key-value cache * \tparam page_storage Whether to store indices or pointers of each active page @@ -56,24 +86,6 @@ struct paged_kv_t { // [batch_size] The offset of the last page for each request in the batch IdType* last_page_len; - /* ------------ Auxliary Information Used in Cooperative Kernels ------------ */ - IdType* cooperative_aux_info; - __host__ __device__ __forceinline__ IdType* cooperative_indptr() const { - return cooperative_aux_info; - } - - __host__ __device__ __forceinline__ IdType* batch_idx_map() const { - return cooperative_aux_info + batch_size + 1; - } - - __host__ __device__ __forceinline__ IdType* chunk_start() const { - return cooperative_aux_info + 2 * batch_size + 1; - } - - __host__ __device__ __forceinline__ IdType* seq_lens_before_split() const { - return cooperative_aux_info + 3 * batch_size + 1; - } - /*! * \brief Construct an empty paged key-value cache */ @@ -86,11 +98,10 @@ struct paged_kv_t { indices(nullptr), ptrs(nullptr), indptr(nullptr), - last_page_len(nullptr), - cooperative_aux_info(nullptr) {} + last_page_len(nullptr) {} /*! - * \brief Construct a paged key-value cache for non-cooperative kernels + * \brief Construct a paged key-value cache * \param num_heads The number of heads * \param page_size The size of each page * \param head_dim The dimension of each head @@ -112,11 +123,10 @@ struct paged_kv_t { data(data), indices(indices), indptr(indptr), - last_page_len(last_page_len), - cooperative_aux_info(nullptr) {} + last_page_len(last_page_len) {} /*! - * \brief Construct a paged key-value cache for non-cooperative kernels + * \brief Construct a paged key-value cache * \param num_heads The number of heads * \param page_size The size of each page * \param head_dim The dimension of each head @@ -135,63 +145,7 @@ struct paged_kv_t { head_dim(head_dim), batch_size(batch_size), ptrs(ptrs), - indptr(indptr), - last_page_len(last_page_len), - cooperative_aux_info(nullptr) {} - - /*! - * \brief Construct a paged key-value cache with auxiliary information for cooperative kernels - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param data The flattened key-value cache - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the batch - * \param cooperative_aux_info The auxiliary information used in cooperative kernels - * \note This constructor should only be used when page_storage == kIndices - */ - __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, - uint32_t head_dim, uint32_t batch_size, - DType* data, IdType* indices, IdType* indptr, - IdType* last_page_len, - IdType* cooperative_aux_info) - : num_heads(num_heads), - page_size(page_size), - head_dim(head_dim), - batch_size(batch_size), - data(data), - indices(indices), - indptr(indptr), - last_page_len(last_page_len), - cooperative_aux_info(cooperative_aux_info) {} - - /*! - * \brief Construct a paged key-value cache with auxiliary information for cooperative kernels - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param ptrs The array of pointers to each active page - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the batch - * \param cooperative_aux_info The auxiliary information used in cooperative kernels - * \note This constructor should only be used when page_storage == kIndices - */ - __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, - uint32_t head_dim, uint32_t batch_size, - DType** ptrs, IdType* indptr, - IdType* last_page_len, - IdType* cooperative_aux_info) - : num_heads(num_heads), - page_size(page_size), - head_dim(head_dim), - batch_size(batch_size), - ptrs(ptrs), - indptr(indptr), - last_page_len(last_page_len), - cooperative_aux_info(cooperative_aux_info) {} + indptr(indptr) {} /*! * \brief Compute the offset of k element in the allocated buffer. diff --git a/include/flashinfer/prefill.cuh b/include/flashinfer/prefill.cuh index f8842c189a..3fd43da7cd 100644 --- a/include/flashinfer/prefill.cuh +++ b/include/flashinfer/prefill.cuh @@ -416,8 +416,8 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs *k_smem_offset_r -= num_frags_y * 2; } -template +template __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, @@ -434,7 +434,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_ kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; const bool out_of_boundary = - (causal ? (kv_idx > kv_len + q_idx - qo_len || (split_kv && kv_idx >= chunk_end)) + (causal ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end)) : kv_idx >= chunk_end); s_frag[fx][fz][reg_id] = out_of_boundary ? DTypeQKAccum(-5e4) : s_frag[fx][fz][reg_id]; } @@ -731,7 +731,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] /*! * \brief FlashAttention prefill CUDA kernel for a single request. - * \tparam split_kv Whether to split kv_len into chunks. + * \tparam partition_kv Whether to split kv_len into chunks. * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). * \tparam causal Whether to use causal attention. * \tparam layout The layout of the input tensor. @@ -746,7 +746,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param k The key tensor. * \param v The value tensor. * \param o The output tensor. - * \param tmp The temporary buffer (used when split_kv is true). + * \param tmp The temporary buffer (used when partition_kv is true). * \param lse The logsumexp value. * \param qkv_info The tensor info of the input tensor. * \param sm_scale The scale factor applied to the softmax score. @@ -755,9 +755,9 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. */ -template +template __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, @@ -771,9 +771,9 @@ __global__ void SinglePrefillWithKVCacheKernel( const uint32_t tx = threadIdx.x, ty = threadIdx.y; const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_size = split_kv ? ceil_div(kv_len, num_chunks) : kv_len; - const uint32_t chunk_start = split_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = split_kv ? min((chunk_idx + 1) * chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = partition_kv ? min((chunk_idx + 1) * chunk_size, kv_len) : kv_len; auto block = cg::this_thread_block(); constexpr uint32_t head_dim = num_frags_y * 16; @@ -803,11 +803,12 @@ __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size, (tx % 8) * num_elems_per_128b()); DTypeOut* o_ptr_base = - split_kv ? ((DTypeOut*)tmp) + chunk_idx * qkv_info.get_num_qo_heads() * head_dim + - qkv_info.get_qo_elem_offset(qo_idx_base * num_chunks, kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()) - : o + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size, - (tx % 8) * num_elems_per_128b()); + partition_kv + ? ((DTypeOut*)tmp) + chunk_idx * qkv_info.get_num_qo_heads() * head_dim + + qkv_info.get_qo_elem_offset(qo_idx_base * num_chunks, kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()) + : o + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size, + (tx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = smem_t::get_permuted_offset(ty * num_frags_x * 16 + tx % 16, tx / 16); @@ -878,7 +879,7 @@ __global__ void SinglePrefillWithKVCacheKernel( // apply mask if (iter >= mask_iteration) { - mask_s( + mask_s( qo_idx_base, chunk_start + iter * 16 * num_frags_z, qo_len, kv_len, chunk_end, s_frag); } @@ -912,10 +913,10 @@ __global__ void SinglePrefillWithKVCacheKernel( // write back write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_idx_base, qo_len, - split_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride); + partition_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride); // write lse - if (lse != nullptr || split_kv) { + if (lse != nullptr || partition_kv) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll @@ -925,7 +926,7 @@ __global__ void SinglePrefillWithKVCacheKernel( const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); const uint32_t qo_idx = qo_idx_base + (tx / 4 + j * 8 + fx * 16) / group_size; if (qo_idx < qo_len) { - if constexpr (split_kv) { + if constexpr (partition_kv) { float* tmp_lse = (float*)(((DTypeOut*)tmp) + qo_len * num_chunks * num_qo_heads * head_dim); tmp_lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = @@ -963,7 +964,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( const tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / group_size)); - constexpr bool split_kv = false; + constexpr bool partition_kv = false; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); @@ -1065,7 +1066,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( // apply mask if (iter >= mask_iteration) { - mask_s( + mask_s( qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); } @@ -1142,7 +1143,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( paged_kv.last_page_len[request_idx]; const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / group_size)); - constexpr bool split_kv = false; + constexpr bool partition_kv = false; constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); @@ -1239,7 +1240,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( // apply mask if (iter >= mask_iteration) { - mask_s( + mask_s( qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); } @@ -1365,8 +1366,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - auto split_kv_kernel = SinglePrefillWithKVCacheKernel< - /*split_kv=*/true, GROUP_SIZE, CAUSAL, LAYOUT, ROTARY_MODE, + auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< + /*partition_kv=*/true, GROUP_SIZE, CAUSAL, LAYOUT, ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; tensor_info_t qkv_info(qo_len, kv_len, @@ -1374,14 +1375,15 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * head_dim * sizeof(DTypeIn); FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - split_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); + partition_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; int num_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( &num_sm, cudaDevAttrMultiProcessorCount, dev_id)); FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, split_kv_kernel, num_threads, smem_size)); + &num_blocks_per_sm, partition_kv_kernel, num_threads, + smem_size)); uint32_t num_chunks = min((num_blocks_per_sm * num_sm) / (num_kv_heads * @@ -1450,20 +1452,20 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* SWITCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { constexpr uint32_t num_threads = num_warps * warp_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; - auto split_kv_kernel = - SinglePrefillWithKVCacheKernel; + auto partition_kv_kernel = + SinglePrefillWithKVCacheKernel; tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( - split_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; int num_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, split_kv_kernel, num_threads, smem_size)); + &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size)); uint32_t num_chunks = min((num_blocks_per_sm * num_sm) / (num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta)), @@ -1472,7 +1474,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv auto kernel = - SinglePrefillWithKVCacheKernel; void* args[] = {(void*)&q, @@ -1506,7 +1508,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* dim3 nblks(ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta), num_chunks, num_kv_heads); dim3 nthrs(32, num_warps); FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)split_kv_kernel, nblks, nthrs, args, smem_size, stream)); + cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream)); const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; FLASHINFER_CUDA_CALL( MergeStates((DTypeOut*)tmp, diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index f0ebcd9fe3..efe78a2250 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -85,7 +85,8 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { } else { state.exec([&](nvbench::launch&) { cudaError_t status = BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q.data()), paged_kv, thrust::raw_pointer_cast(o.data()), nullptr, + thrust::raw_pointer_cast(q.data()), paged_kv, kv_partition_info_t(), + thrust::raw_pointer_cast(o.data()), nullptr, /*lse=*/nullptr, num_qo_heads, rotary_mode); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); diff --git a/src/bench_single_decode.cu b/src/bench_single_decode.cu index 6738fe6fa9..70ba106298 100644 --- a/src/bench_single_decode.cu +++ b/src/bench_single_decode.cu @@ -36,7 +36,7 @@ void bench_flashinfer_single_decode(nvbench::state& state) { thrust::device_vector K(seq_len * num_kv_heads * head_dim); thrust::device_vector V(seq_len * num_kv_heads * head_dim); thrust::device_vector O(num_qo_heads * head_dim); - thrust::device_vector tmp(512 * num_qo_heads * head_dim); + thrust::device_vector tmp(16 * 1024 * 1024); // Provide throughput information: state.add_global_memory_reads( diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 6ae839b25a..c747e26704 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -102,7 +102,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si if (!cooperative) { // use non-cooperative kernel cudaError_t status = flashinfer::BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q_device.data()), paged_kv, + thrust::raw_pointer_cast(q_device.data()), paged_kv, kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, rotary_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); diff --git a/src/test_single_decode.cu b/src/test_single_decode.cu index 1432290409..af4fd98108 100644 --- a/src/test_single_decode.cu +++ b/src/test_single_decode.cu @@ -40,7 +40,7 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si thrust::device_vector K(K_host); thrust::device_vector V(V_host); thrust::device_vector O(O_host); - thrust::device_vector tmp(512 * num_qo_heads * head_dim); + thrust::device_vector tmp(16 * 1024 * 1024); std::vector o_ref_host; o_ref_host = cpu_reference::single_mha(Q_host, K_host, V_host, 1, seq_len, num_qo_heads, @@ -55,7 +55,7 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si << cudaGetErrorString(status); thrust::host_vector o_host = O; - thrust::host_vector tmp_host = tmp; + thrust::host_vector tmp_host = tmp; size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 79adde7cce..74d5b80340 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -166,8 +166,8 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DL q->dtype, dtype_in, {SWITCH_TVM_CUDA_DTYPE(o->dtype, dtype_out, { cudaError_t status = SingleDecodeWithKVCache( (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, - (float*)tmp->data, num_qo_heads, num_kv_heads, seq_len, head_dim, QKVLayout(qkv_layout), - RotaryMode(rotary_mode), rope_scale, rope_theta, 0); + (dtype_out*)tmp->data, num_qo_heads, num_kv_heads, seq_len, head_dim, + QKVLayout(qkv_layout), RotaryMode(rotary_mode), rope_scale, rope_theta, 0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } From e3d6fe5cfec1f6adfc213ae8f3486b4a19982241 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 02:10:30 -0500 Subject: [PATCH 2/7] bugfix --- include/flashinfer/handler.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index 5b56855f10..2e8305d942 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -58,7 +58,7 @@ class BatchDecodeHandler { IdType* GetBatchIdxMap() const { if (int_buffer_ != nullptr) { return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + - batch_size_before_partition_ + 1; + batch_size_before_partition_ + 2; } else { return nullptr; } From 29a35f250fd99f05aeb7aae8518650c16b54ae36 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 10:26:24 +0000 Subject: [PATCH 3/7] change parameters --- include/flashinfer/decode.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh index f4a0cfd059..5578a08f26 100644 --- a/include/flashinfer/decode.cuh +++ b/include/flashinfer/decode.cuh @@ -733,7 +733,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr) { const uint32_t GROUP_SIZE = num_qo_heads / num_kv_heads; - if (seq_len <= 128U / uint32_t(std::sqrt(GROUP_SIZE))) { + if (seq_len <= 256U) { tmp_size = 0; } else { SWITCH_GQA_GROUP_SIZE( @@ -1099,7 +1099,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( } std::tie(max_num_pages_per_batch, new_batch_size) = SplitPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, - num_pages, 512 / page_size); + num_pages, 128 / page_size); if (new_batch_size == batch_size) { // do not use partition-kv kernel for short sequence tmp_size = 0; From 7ba8176eaf56d44ff7403efd6c83cea458937515 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 10:49:14 +0000 Subject: [PATCH 4/7] rename split to partition --- include/flashinfer/decode.cuh | 32 ++++++++++++++++---------------- include/flashinfer/handler.cuh | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh index 5578a08f26..c893077f9b 100644 --- a/include/flashinfer/decode.cuh +++ b/include/flashinfer/decode.cuh @@ -907,7 +907,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut } /*! - * \brief Split Paged KV-Cache into multiple chunks on KV sequence length + * \brief Partition Paged KV-Cache into multiple chunks on KV sequence length * \tparam IdType A template type indicates the index data type * \param old_batch_size The batch size of the old Paged KV-Cache * \param old_page_indptr_h The host-side page indptr of the old Paged KV-Cache @@ -918,13 +918,13 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( +cudaError_t PartitionPagedCacheKVComputeAuxiliaryInfo( const uint32_t max_num_pages_per_batch, const uint32_t old_batch_size, const uint32_t page_size, IdType* old_indptr, IdType* old_last_page_len, IdType* new_indptr_d, IdType* new_last_page_len_d, IdType* chunk_indptr_d, IdType* batch_idx_map_d, - IdType* chunk_start_d, IdType* seq_lens_before_split_d, cudaStream_t stream = nullptr) { + IdType* chunk_start_d, IdType* seq_lens_before_partition_d, cudaStream_t stream = nullptr) { std::vector new_page_indptr_h{0}, new_last_page_len_h, chunk_indptr_h{0}, batch_idx_map_h, - chunk_start_pos_h, seq_lens_before_split_h; + chunk_start_pos_h, seq_lens_before_partition_h; std::vector old_indptr_h(old_batch_size + 1), old_last_page_len_h(old_batch_size); if (is_device_ptr(old_indptr)) { @@ -949,9 +949,9 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( new_last_page_len_h.push_back(0); batch_idx_map_h.push_back(batch_idx); chunk_start_pos_h.push_back(0); - seq_lens_before_split_h.push_back(0); + seq_lens_before_partition_h.push_back(0); } else { - uint32_t seq_len_before_split = + uint32_t seq_len_before_partition = (old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx] - 1) * page_size + old_last_page_len_h[batch_idx]; for (uint32_t j = 0; j < num_chunks; ++j) { @@ -961,7 +961,7 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( new_last_page_len_h.push_back(is_last ? old_last_page_len_h[batch_idx] : page_size); batch_idx_map_h.push_back(batch_idx); chunk_start_pos_h.push_back(j * max_num_pages_per_batch * page_size); - seq_lens_before_split_h.push_back(seq_len_before_split); + seq_lens_before_partition_h.push_back(seq_len_before_partition); } } } @@ -981,15 +981,15 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( FLASHINFER_CUDA_CALL(cudaMemcpyAsync(chunk_start_d, chunk_start_pos_h.data(), sizeof(IdType) * chunk_start_pos_h.size(), cudaMemcpyHostToDevice, stream)); - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(seq_lens_before_split_d, seq_lens_before_split_h.data(), - sizeof(IdType) * seq_lens_before_split_h.size(), - cudaMemcpyHostToDevice, stream)); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync( + seq_lens_before_partition_d, seq_lens_before_partition_h.data(), + sizeof(IdType) * seq_lens_before_partition_h.size(), cudaMemcpyHostToDevice, stream)); return cudaSuccess; } /*! * \brief Compute the maximum number of pages per batch and the new batch size - * after we split Paged KV-Cache into multiple chunks on KV sequence length + * after we partition Paged KV-Cache into multiple chunks on KV sequence length * dimension. * \tparam IdType A template type indicates the index data type * \param max_grid_size The maximum grid size of the kernel @@ -998,10 +998,10 @@ cudaError_t SplitPagedCacheKVComputeAuxiliaryInfo( * \param max_num_pages_per_batch_lb The pre-set lower bound of maximum number of * pages per batch, default to 1 * \return (max_num_pages_per_batch, new_batch_size) The number of pages per batch and - * the new batch size after the split. + * the new batch size after the partition. */ template -std::pair SplitPagedKVCacheBinarySearchMinNumPagePerBatch( +std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( const uint32_t max_grid_size, const uint32_t num_kv_heads, const std::vector& num_pages, const uint32_t min_num_pages_per_batch = 1) { uint32_t low = min_num_pages_per_batch, high = 0; @@ -1038,7 +1038,7 @@ std::pair SplitPagedKVCacheBinarySearchMinNumPagePerBatch( * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel * \param max_grid_size The maximum grid size that can be used in a partiton-kv kernel * \param max_num_pages_per_batch The maximum number of pages per batch - * \param new_batch_size The new batch size after the split + * \param new_batch_size The new batch size after the partition * \param paged_kv The paged kv cache data structure * \param num_qo_heads A integer indicates the number of heads of query and output * \param rotary_mode The rotary mode @@ -1098,8 +1098,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( num_pages[batch_idx] = page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx]; } std::tie(max_num_pages_per_batch, new_batch_size) = - SplitPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, - num_pages, 128 / page_size); + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, + num_pages, 128 / page_size); if (new_batch_size == batch_size) { // do not use partition-kv kernel for short sequence tmp_size = 0; diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index 2e8305d942..be82b80de4 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -99,7 +99,7 @@ class BatchDecodeHandler { FLASHINFER_CUDA_CALL(cudaMallocAsync( &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), stream_)); - FLASHINFER_CUDA_CALL(SplitPagedCacheKVComputeAuxiliaryInfo( + FLASHINFER_CUDA_CALL(PartitionKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, GetNewIndPtr(), GetNewLastPageLen(), GetChunkIndPtr(), GetBatchIdxMap(), GetChunkStartPos(), From 90fe6bdf072176942ceaa888f320f25aa3137873 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 10:53:06 +0000 Subject: [PATCH 5/7] bugfix --- include/flashinfer/handler.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index be82b80de4..971e40a324 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -99,7 +99,7 @@ class BatchDecodeHandler { FLASHINFER_CUDA_CALL(cudaMallocAsync( &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), stream_)); - FLASHINFER_CUDA_CALL(PartitionKVCacheComputeAuxiliaryInfo( + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, GetNewIndPtr(), GetNewLastPageLen(), GetChunkIndPtr(), GetBatchIdxMap(), GetChunkStartPos(), From 8c7c45ef895a83ec962248b135cd880b7267a25f Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 10:54:35 +0000 Subject: [PATCH 6/7] bugfix --- include/flashinfer/decode.cuh | 2 +- include/flashinfer/handler.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh index c893077f9b..cedc97f022 100644 --- a/include/flashinfer/decode.cuh +++ b/include/flashinfer/decode.cuh @@ -918,7 +918,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut * \return status Indicates whether CUDA calls are successful */ template -cudaError_t PartitionPagedCacheKVComputeAuxiliaryInfo( +cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( const uint32_t max_num_pages_per_batch, const uint32_t old_batch_size, const uint32_t page_size, IdType* old_indptr, IdType* old_last_page_len, IdType* new_indptr_d, IdType* new_last_page_len_d, IdType* chunk_indptr_d, IdType* batch_idx_map_d, diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index 971e40a324..a9983b6d48 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -99,7 +99,7 @@ class BatchDecodeHandler { FLASHINFER_CUDA_CALL(cudaMallocAsync( &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), stream_)); - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + FLASHINFER_CUDA_CALL(PartitionPagedCacheKVComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, GetNewIndPtr(), GetNewLastPageLen(), GetChunkIndPtr(), GetBatchIdxMap(), GetChunkStartPos(), From c14cfd6a563c2f01633432d083107c4c62caac73 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Thu, 18 Jan 2024 10:55:13 +0000 Subject: [PATCH 7/7] bugfix --- include/flashinfer/handler.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/handler.cuh b/include/flashinfer/handler.cuh index a9983b6d48..971e40a324 100644 --- a/include/flashinfer/handler.cuh +++ b/include/flashinfer/handler.cuh @@ -99,7 +99,7 @@ class BatchDecodeHandler { FLASHINFER_CUDA_CALL(cudaMallocAsync( &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), stream_)); - FLASHINFER_CUDA_CALL(PartitionPagedCacheKVComputeAuxiliaryInfo( + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, GetNewIndPtr(), GetNewLastPageLen(), GetChunkIndPtr(), GetBatchIdxMap(), GetChunkStartPos(),