refactor: move gqa_group_size from template parameter to input arguments #301

Merged
merged 19 commits on Jun 15, 2024
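
In short: the GQA group size (num_qo_heads / num_kv_heads) is no longer a kernel template parameter that every variant must be pre-compiled for; callers now pass num_kv_heads as a plain runtime argument, and the handler selects the specialized kernel internally via DISPATCH_GQA_GROUP_SIZE (see the handler.cuh hunk below). The following is a minimal, self-contained sketch of that runtime-to-compile-time dispatch pattern; the macro body is an illustration rather than the library's actual definition, and the supported sizes 1, 4, 6, and 8 are simply the values from the FLASHINFER_GEN_GROUP_SIZES list removed from cmake/config.cmake.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Kernel-side code still sees the group size as a compile-time constant.
template <uint32_t GROUP_SIZE>
void DecodeKernelStub(uint32_t head_dim) {
  std::printf("decode kernel specialized for GROUP_SIZE=%u, head_dim=%u\n", GROUP_SIZE, head_dim);
}

// Illustrative stand-in for a DISPATCH_GQA_GROUP_SIZE-style macro: map a runtime
// group size onto a constexpr template parameter. The actual macro body is not part
// of this diff; the cases mirror the old FLASHINFER_GEN_GROUP_SIZES list (1 4 6 8).
#define DISPATCH_GROUP_SIZE_SKETCH(group_size, GROUP_SIZE, ...)   \
  switch (group_size) {                                           \
    case 1: {                                                     \
      constexpr uint32_t GROUP_SIZE = 1;                          \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case 4: {                                                     \
      constexpr uint32_t GROUP_SIZE = 4;                          \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case 6: {                                                     \
      constexpr uint32_t GROUP_SIZE = 6;                          \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    case 8: {                                                     \
      constexpr uint32_t GROUP_SIZE = 8;                          \
      __VA_ARGS__                                                 \
      break;                                                      \
    }                                                             \
    default:                                                      \
      throw std::invalid_argument("unsupported GQA group size");  \
  }

int main() {
  // The group size is now derived from plain runtime arguments, not a template parameter.
  uint32_t num_qo_heads = 32, num_kv_heads = 8, head_dim = 128;
  DISPATCH_GROUP_SIZE_SKETCH(num_qo_heads / num_kv_heads, GROUP_SIZE,
                             { DecodeKernelStub<GROUP_SIZE>(head_dim); });
  return 0;
}
```

The practical upshot is that the set of supported group sizes is handled in one dispatch site at runtime instead of being baked into the grid of generated kernels, which is why the FLASHINFER_GEN_GROUP_SIZES build option can be dropped below.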
279 changes: 133 additions & 146 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion cmake/config.cmake
@@ -22,7 +22,6 @@ set(FLASHINFER_FASTDIV_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
376 changes: 190 additions & 186 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.

220 changes: 112 additions & 108 deletions include/flashinfer/attention/handler.cuh
@@ -297,121 +297,125 @@ class BatchDecodeHandler {

bool* GetBlockValidMask() const { return block_valid_mask_; }

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t page_size) {
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t page_size) {
batch_size_before_partition_ = batch_size;
uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ,
DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size,
/*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
batch_size_after_partition_ = new_batch_size;
if (IsCUDAGraphEnabled()) {
if (batch_size != fixed_batch_size_) {
std::ostringstream err_msg;
err_msg << "The running batch size " << batch_size
<< " is not compatible with the fixed batch size " << fixed_batch_size_
<< " initialized for CUDAGraph";
throw std::runtime_error(err_msg.str());
}
size_t padded_batch_size = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ =
allocator.aligned_alloc<void>(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size,
batch_size, indptr, num_qo_heads, page_size,
/*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
batch_size_after_partition_ = new_batch_size;
if (IsCUDAGraphEnabled()) {
if (batch_size != fixed_batch_size_) {
std::ostringstream err_msg;
err_msg << "The running batch size " << batch_size
<< " is not compatible with the fixed batch size " << fixed_batch_size_
<< " initialized for CUDAGraph";
throw std::runtime_error(err_msg.str());
}
size_t padded_batch_size = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
} else {
block_valid_mask_ = nullptr;
padded_batch_size_ = batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
block_valid_mask_ = nullptr;
padded_batch_size_ = batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
block_valid_mask_ = nullptr;
// do not pad the batch size when not using CUDAGraph
padded_batch_size_ = batch_size_after_partition_;
if (tmp_size > 0) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
tmp_s_ = (char*)tmp_v_ +
num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
new_indptr_ =
allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_,
/*block_valid_mask_h=*/nullptr,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
// do not pad the batch size when not using CUDAGraph
padded_batch_size_ = batch_size_after_partition_;
if (tmp_size > 0) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
tmp_s_ = (char*)tmp_v_ +
num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
new_indptr_ =
allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>(
(batch_size_before_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_,
/*block_valid_mask_h=*/nullptr,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
}
}
}
});
forward_started_ = true;
return cudaSuccess;
}
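
One detail worth noting in BeginForwardDispatched (unchanged in substance by this PR, only re-indented under the dispatch macro): the per-batch auxiliary arrays are carved out of one device workspace with AlignedAllocator, each device pointer gets a host twin at the same byte offset inside the page-locked buffer, the host twins are populated, and everything is pushed to the device in a single copy of num_bytes_to_copy bytes. The snippet below is a simplified host-only sketch of that layout trick; the names and the bump allocator are illustrative, and the final std::memcpy stands in for the single host-to-device cudaMemcpyAsync issued by the real code.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative bump allocator over a pre-allocated workspace (stand-in for AlignedAllocator).
struct BumpAllocator {
  char* ptr;
  char* end;
  void* Alloc(size_t bytes, size_t align) {
    ptr = reinterpret_cast<char*>((reinterpret_cast<uintptr_t>(ptr) + align - 1) & ~(align - 1));
    void* ret = ptr;
    ptr += bytes;
    return ret;  // the real allocator would also check ptr <= end
  }
};

int main() {
  constexpr uint32_t batch_size = 4;
  std::vector<char> device_workspace(1 << 12);    // stand-in for the device workspace buffer
  std::vector<char> page_locked_buffer(1 << 12);  // stand-in for the pinned host buffer

  BumpAllocator alloc{device_workspace.data(), device_workspace.data() + device_workspace.size()};

  // Carve sub-buffers out of the "device" workspace.
  auto* new_indptr = static_cast<int32_t*>(alloc.Alloc((batch_size + 1) * sizeof(int32_t), 16));
  auto* last_page_len = static_cast<int32_t*>(alloc.Alloc(batch_size * sizeof(int32_t), 16));

  // Host twins live at the same byte offsets inside the page-locked buffer; the first
  // allocation has offset zero, mirroring new_indptr_h_ = page_locked_buffer_ above.
  char* base = reinterpret_cast<char*>(new_indptr);
  auto* new_indptr_h = reinterpret_cast<int32_t*>(page_locked_buffer.data());
  auto* last_page_len_h = reinterpret_cast<int32_t*>(
      page_locked_buffer.data() + (reinterpret_cast<char*>(last_page_len) - base));

  // Populate the host twins with some example values...
  for (uint32_t i = 0; i <= batch_size; ++i) new_indptr_h[i] = int32_t(i * 2);
  for (uint32_t i = 0; i < batch_size; ++i) last_page_len_h[i] = 1;

  // ...then push everything carved out so far in one copy
  // (one cudaMemcpyAsync(device_buffer, host_buffer, num_bytes_to_copy, ...) in the real code).
  size_t num_bytes_to_copy = static_cast<size_t>(alloc.ptr - base);
  std::memcpy(new_indptr, page_locked_buffer.data(), num_bytes_to_copy);
  return 0;
}
```

Keeping the host staging area laid out identically to the device workspace is what allows the single num_bytes_to_copy transfer at the end instead of one copy per auxiliary array.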