
Commit da83cf5
Bugfix: bugfix to #322 (#325)
Some final commits for the #322 bugfix were missing; this change completes it.
yzh119 authored Jun 21, 2024
1 parent 545b9ca commit da83cf5
Showing 10 changed files with 51 additions and 65 deletions.
37 changes: 18 additions & 19 deletions include/flashinfer/attention/handler.cuh
@@ -92,7 +92,7 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
     const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) {
   int64_t low = min_kv_chunk_size, high = 0;
   int64_t batch_size = packed_qo_len_arr.size();
-  int64_t max_kv_len;
+  int64_t max_kv_len = 0;
   for (const int64_t& kv_len : kv_len_arr) {
     max_kv_len = std::max(max_kv_len, kv_len);
   }
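
A note on the hunk above: max_kv_len was previously declared without an initializer, so the std::max fold below it read an indeterminate value, which is undefined behavior in C++. A minimal self-contained sketch of the corrected pattern (illustrative names, not the exact FlashInfer code):

#include <algorithm>
#include <cstdint>
#include <vector>

// Fold std::max over the KV lengths. The accumulator must start from a
// defined value; 0 is a safe identity because lengths are non-negative.
int64_t MaxKVLen(const std::vector<int64_t>& kv_len_arr) {
  int64_t max_kv_len = 0;
  for (const int64_t kv_len : kv_len_arr) {
    max_kv_len = std::max(max_kv_len, kv_len);
  }
  return max_kv_len;
}
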
@@ -174,9 +174,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     new_batch_size = batch_size;
   } else {
     // compute max_num_pages_per_batch and new_batch_size
-    std::vector<IdType> page_indptr_h(batch_size + 1), num_pages(batch_size);
+    std::vector<IdType> num_pages(batch_size);
     for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-      num_pages[batch_idx] = page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx];
+      num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
     }
     std::tie(max_num_pages_per_batch, new_batch_size) =
         PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
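
A note on the hunk above: the old code allocated a fresh page_indptr_h vector and immediately took adjacent differences of it without ever filling it, so num_pages was built from indeterminate entries. The fix differences the caller-supplied kv_indptr_h array instead. A hedged sketch of the indptr-differencing idiom (CSR-style offsets; names illustrative):

#include <cstdint>
#include <vector>

// kv_indptr_h holds batch_size + 1 offsets; entry i is where request i's
// pages begin, so adjacent differences give per-request page counts.
std::vector<int32_t> NumPagesPerRequest(const int32_t* kv_indptr_h,
                                        uint32_t batch_size) {
  std::vector<int32_t> num_pages(batch_size);
  for (uint32_t i = 0; i < batch_size; ++i) {
    num_pages[i] = kv_indptr_h[i + 1] - kv_indptr_h[i];
  }
  return num_pages;
}
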
@@ -517,14 +517,16 @@ class BatchDecodeHandler {
 };

 template <typename IdType>
-cudaError_t PrefillSplitQOKVIndptr(
-    bool& split_kv, uint32_t& split_max_batch_size, uint32_t& total_num_tiles_q,
-    uint32_t& new_batch_size, WarpLayout& warp_layout, uint32_t& kv_chunk_size,
-    uint32_t& total_num_rows, std::vector<IdType>& request_indices,
-    std::vector<IdType>& qo_tile_indices, std::vector<IdType>& kv_tile_indices,
-    std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr, IdType* qo_indptr_h,
-    IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
-    uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, cudaStream_t stream = nullptr) {
+cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_size,
+                                   uint32_t& total_num_tiles_q, uint32_t& new_batch_size,
+                                   WarpLayout& warp_layout, uint32_t& kv_chunk_size,
+                                   uint32_t& total_num_rows, std::vector<IdType>& request_indices,
+                                   std::vector<IdType>& qo_tile_indices,
+                                   std::vector<IdType>& kv_tile_indices,
+                                   std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr,
+                                   IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
+                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
+                                   uint32_t page_size) {
   request_indices.clear();
   qo_tile_indices.clear();
   kv_tile_indices.clear();
@@ -536,8 +538,6 @@ cudaError_t PrefillSplitQOKVIndptr(
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
   total_num_rows = qo_indptr_h[batch_size];

-  bool has_kv_last_page_len = kv_last_page_len_h != nullptr;
-
   // step 0: get the number of SMs
   int num_sm = 0;
   int dev_id = 0;
@@ -570,7 +570,7 @@
   // step 2: determine kv_chunk_size
   std::tie(split_kv, kv_chunk_size, new_batch_size) =
       PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr,
-                                     qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size));
+                                     qo_chunk_size, /*min_kv_chunk_size=*/(512 / page_size));

   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
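
A note on the hunk above: the floor that the binary search may assign to a KV chunk rises from 128 tokens to 512 tokens, expressed here in pages. A worked example, assuming a 16-token page (the page size is an illustrative choice, not a value from this commit):

#include <cstdint>

// With page_size = 16, the new floor is 512 / 16 = 32 pages per chunk,
// replacing the old floor of 128 / 16 = 8 pages, so split-KV now produces
// coarser chunks and fewer partial results to merge.
constexpr uint32_t page_size = 16;
constexpr uint32_t old_min_kv_chunk_size = 128 / page_size;  // 8 pages
constexpr uint32_t new_min_kv_chunk_size = 512 / page_size;  // 32 pages
static_assert(new_min_kv_chunk_size == 4 * old_min_kv_chunk_size,
              "the minimum KV chunk is four times larger than before");
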
@@ -656,9 +656,8 @@ class BatchPrefillHandler {

   template <typename DTypeOut, typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h,
-                           IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size,
-                           uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                           uint32_t page_size) {
+                           IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads,
+                           uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) {
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
       err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -672,8 +671,8 @@ class BatchPrefillHandler {
     FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr(
         split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, warp_layout_,
         kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec,
-        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h,
-        kv_last_page_len_h, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, stream_));
+        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h, batch_size,
+        num_qo_heads, num_kv_heads, head_dim, page_size));
     const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_);

     if (IsCUDAGraphEnabled()) {
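
Taken together, the handler.cuh hunks simplify the prefill entry point: BatchPrefillHandler::BeginForward now takes only the two host-side indptr arrays, with no kv_last_page_len pointer and no stream argument threaded through PrefillSplitQOKVIndptr. A hedged usage sketch of the new signature, assuming the handler is visible as flashinfer::BatchPrefillHandler and using illustrative workspace size, head counts, and indptr contents (none of these values come from the commit):

#include <cuda_fp16.h>
#include <thrust/device_vector.h>

#include <vector>

#include "flashinfer/attention/handler.cuh"

int main() {
  flashinfer::BatchPrefillHandler handler;
  thrust::device_vector<char> buffer(16 * 1024 * 1024);  // assumed workspace
  // CSR-style host arrays for 3 requests: query rows and KV pages.
  std::vector<int32_t> qo_indptr_h{0, 17, 29, 60};
  std::vector<int32_t> kv_indptr_h{0, 8, 12, 25};
  cudaError_t status = handler.BeginForward<half, int32_t>(
      (void*)thrust::raw_pointer_cast(buffer.data()), buffer.size(),
      qo_indptr_h.data(), kv_indptr_h.data(),
      /*batch_size=*/3, /*num_qo_heads=*/32, /*num_kv_heads=*/8,
      /*head_dim=*/128, /*page_size=*/16);
  return status == cudaSuccess ? 0 : 1;
}
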
13 changes: 5 additions & 8 deletions python/csrc/batch_prefill.cu
@@ -22,9 +22,8 @@ using namespace flashinfer;

 void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
-    torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
-    unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
-    torch::Tensor empty_q_data) {
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -33,7 +32,6 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
   CHECK_DIM(1, workspace_buffer);
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
-  paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kCPU).to(torch::kInt32);

   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -43,9 +41,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
         static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()),
-        static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
-        static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads,
-        num_kv_heads, head_dim, page_size);
+        static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
+        head_dim, page_size);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
                 cudaGetErrorString(status));
     return true;
@@ -285,7 +282,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
         static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
-        /*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+        batch_size, num_qo_heads, num_kv_heads, head_dim,
         /*page_size=*/1);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
                 cudaGetErrorString(status));
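
A design note on the ragged-KV wrapper above: it reuses the paged prefill handler by treating every token as its own page, which is why page_size is pinned to 1 and why, after this commit, the call simply omits the last-page-length argument rather than passing nullptr. A hedged illustration of why that argument is redundant at page size 1:

#include <cassert>
#include <cstdint>

int main() {
  // With page_size = 1, a request's page count equals its token count and
  // the last page is always full, so a last-page length carries no extra
  // information (kv_len here is an illustrative value).
  constexpr uint32_t page_size = 1;
  const uint32_t kv_len = 25;
  const uint32_t num_pages = (kv_len + page_size - 1) / page_size;
  const uint32_t last_page_len = kv_len - (num_pages - 1) * page_size;
  assert(num_pages == kv_len && last_page_len == page_size);
  return 0;
}
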
6 changes: 3 additions & 3 deletions python/csrc/flashinfer_ops.h
@@ -112,9 +112,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    torch::Tensor page_kv_indptr, torch::Tensor page_kv_last_page_len,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-                    unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data);
+                    torch::Tensor page_kv_indptr, unsigned int batch_size,
+                    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+                    unsigned page_size, torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
1 change: 0 additions & 1 deletion python/flashinfer/decode.py
@@ -730,7 +730,6 @@ def begin_forward(
             self._workspace_buffer,
             self._qo_indptr_buf,
             indptr,
-            last_page_len,
             batch_size,
             num_qo_heads,
             num_kv_heads,
1 change: 0 additions & 1 deletion python/flashinfer/prefill.py
@@ -773,7 +773,6 @@ def begin_forward(
             self._workspace_buffer,
             qo_indptr,
             paged_kv_indptr,
-            paged_kv_last_page_len,
             batch_size,
             num_qo_heads,
             num_kv_heads,
7 changes: 3 additions & 4 deletions src/bench_batch_decode.cu
@@ -149,10 +149,9 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
   size_t workspace_size_in_bytes = 128 * 1024 * 1024;
   thrust::device_vector<char> buffer(workspace_size_in_bytes);

-  handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
-                                   workspace_size_in_bytes, qo_indptr_h.data(),
-                                   kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size,
-                                   num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.BeginForward<T, int32_t>(
+      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
+      kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

   state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
     cudaError_t status =
6 changes: 2 additions & 4 deletions src/bench_cascade.cu
@@ -248,8 +248,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
   thrust::device_vector<char> buffer(workspace_size_in_bytes);
   cascade_handler.BeginForward<T, int32_t>(
       (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
-      kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim, page_size);
+      kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
   state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
     timer.start();
     cudaError_t status = SinglePrefillWithKVCache(
@@ -305,8 +304,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
   thrust::device_vector<char> buffer(workspace_size_in_bytes);
   baseline_handler.BeginForward<T, int32_t>(
       (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
-      kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim, page_size);
+      kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
   state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
     timer.start();
     cudaError_t status =
21 changes: 9 additions & 12 deletions src/test_batch_prefill.cu
@@ -104,8 +104,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n

   handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
                                    workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
-                                   kv_last_page_len.data(), batch_size, num_qo_heads,
-                                   num_kv_heads, head_dim, page_size);
+                                   batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

   for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) {
     auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices,
@@ -190,10 +189,9 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo
   thrust::device_vector<int32_t> append_indptr_device(append_indptr);
   thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);

-  handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
-                                   workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
-                                   /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads,
-                                   num_kv_heads, head_dim, /*page_size=*/1);
+  handler.BeginForward<T, int32_t>(
+      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
+      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1);

   auto status = BatchPrefillWithRaggedKVCacheWrapper<T, T, int32_t>(
       &handler, thrust::raw_pointer_cast(queries_device.data()),
@@ -321,8 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si

   handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
                                    workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
-                                   kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads,
-                                   head_dim, page_size);
+                                   batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

   auto status =
       BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
@@ -416,10 +413,10 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
   size_t workspace_size_in_bytes = 32 * 1024 * 1024;
   thrust::device_vector<char> buffer(workspace_size_in_bytes);

-  handler.BeginForward<T, int32_t>(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(), kv_last_page_len.data(),
-      /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
+                                   workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
+                                   /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim,
+                                   page_size);

   auto status =
       BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
16 changes: 8 additions & 8 deletions src/test_cascade.cu
@@ -409,14 +409,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
   thrust::device_vector<char> buffer_baseline(workspace_size_in_bytes),
       buffer_cascade(workspace_size_in_bytes);

-  baseline_handler.BeginForward<T, int32_t>(
-      (void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
-      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
-  cascade_handler.BeginForward<T, int32_t>(
-      (void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size,
-      num_qo_heads, num_kv_heads, head_dim, page_size);
+  baseline_handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer_baseline.data()),
+                                            workspace_size_in_bytes, qo_indptr_h.data(),
+                                            kv_indptr_combined_h.data(), batch_size, num_qo_heads,
+                                            num_kv_heads, head_dim, page_size);
+  cascade_handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer_cascade.data()),
+                                           workspace_size_in_bytes, qo_indptr_h.data(),
+                                           kv_indptr_unique_h.data(), batch_size, num_qo_heads,
+                                           num_kv_heads, head_dim, page_size);

   cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<page_storage, kv_layout, T, T, int32_t>(
       &baseline_handler, thrust::raw_pointer_cast(q_d.data()),
8 changes: 3 additions & 5 deletions src/tvm_wrapper.cu
@@ -272,8 +272,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q

 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
     int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
-    DLTensor* kv_last_page_len, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
-    int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) {
+    int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
+    int64_t page_size, TVMStreamHandle copy_stream) {
   CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
   size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
@@ -290,8 +290,6 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
       static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
       static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
       static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
-      static_cast<dtype_idx*>(kv_last_page_len->data) +
-          kv_last_page_len->byte_offset / sizeof(dtype_idx),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
   if (status != cudaSuccess) {
     LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status);
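
An aside on the pointer arithmetic in these call sites: a DLPack tensor carries a byte_offset that must be folded into its data pointer before use, expressed in elements of the stored type. A minimal hedged sketch (the int32_t element type is an assumption for illustration):

#include <cstdint>

#include <dlpack/dlpack.h>

// Advance a DLTensor's base pointer by its byte offset, measured in units of
// the index element type the tensor actually stores.
int32_t* IndptrData(const DLTensor* t) {
  return static_cast<int32_t*>(t->data) + t->byte_offset / sizeof(int32_t);
}
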
@@ -568,7 +566,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
       static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
       static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
       static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
-      /*kv_last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+      batch_size, num_qo_heads, num_kv_heads, head_dim,
       /*page_size=*/1);
   if (status != cudaSuccess) {
     LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache BeginForward error "
