diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 59e383ef..3aa7c115 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -180,7 +180,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
   }
   std::tie(max_num_pages_per_batch, new_batch_size) =
       PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
-                                                          128 / page_size);
+                                                          std::max(128 / page_size, 1U));
   if (new_batch_size == batch_size && !enable_cuda_graph) {
     // do not use partition-kv kernel for short sequence, when not using CUDAGraph
     split_kv = false;
@@ -564,9 +564,9 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout);
 
   // step 2: determine kv_chunk_size
-  std::tie(split_kv, kv_chunk_size, new_batch_size) =
-      PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr,
-                                     qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size));
+  std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+      max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr, qo_chunk_size,
+      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
 
   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
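
Both hunks clamp the derived minimum KV chunk size. Since `page_size` is an unsigned integer at both call sites, `128 / page_size` is integer division and truncates to 0 whenever `page_size` exceeds 128, handing the binary-search helpers a degenerate lower bound; `std::max(..., 1U)` pins the minimum at one page. The standalone snippet below is an illustrative sketch of the truncation, not part of the patch; the page-size values are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative sketch (not from the patch): `128 / page_size` is unsigned
// integer division, so any page_size > 128 truncates to 0, while the
// std::max clamp keeps the minimum KV chunk size at one page.
int main() {
  for (uint32_t page_size : {16u, 64u, 128u, 256u, 512u}) {
    uint32_t unclamped = 128 / page_size;              // 0 when page_size > 128
    uint32_t clamped = std::max(128 / page_size, 1u);  // never below 1
    std::printf("page_size=%4u  unclamped=%u  clamped=%u\n", page_size,
                unclamped, clamped);
  }
}
```

Clamping to one page, rather than capping `page_size` itself, keeps both the decode partition search and the prefill KV-chunk-size search valid for arbitrarily large pages without touching the search logic.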