Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 113 additions & 167 deletions include/flashinfer/decode.cuh

Large diffs are not rendered by default.

79 changes: 44 additions & 35 deletions include/flashinfer/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,56 +30,53 @@ namespace flashinfer {

class BatchDecodeHandler {
public:
float* GetTempFloatBuffer() const { return float_buffer_; }
template <typename DType>
DType* GetTempFloatBuffer() const {
return (DType*)float_buffer_;
}
template <typename IdType>
IdType* GetNewIndPtr() const {
return (IdType*)int_buffer_;
}
template <typename IdType>
IdType* GetNewLastPageLen() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + new_batch_size_ + 1;
} else {
return nullptr;
}
}
// cooperative_aux_info starts with cooperative_indptr
template <typename IdType>
IdType* GetCooperativeAuxInfo() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
return ((IdType*)int_buffer_) + batch_size_after_partition_ + 1;
} else {
return nullptr;
}
}
template <typename IdType>
IdType* GetCooperativeIndPtr() const {
IdType* GetChunkIndPtr() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + 1;
} else {
return nullptr;
}
}
template <typename IdType>
IdType* GetBatchIndexMap() const {
IdType* GetBatchIdxMap() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + 3 * new_batch_size_ + 2;
return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ +
batch_size_before_partition_ + 2;
} else {
return nullptr;
}
}
template <typename IdType>
IdType* GetChunkStartPos() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + 4 * new_batch_size_ + 2;
return ((IdType*)int_buffer_) + 3 * batch_size_after_partition_ +
batch_size_before_partition_ + 2;
} else {
return nullptr;
}
}
template <typename IdType>
IdType* GetSeqLengthsBeforeSplit() const {
IdType* GetSeqLengthsBeforePartition() const {
if (int_buffer_ != nullptr) {
return ((IdType*)int_buffer_) + 5 * new_batch_size_ + 2;
return ((IdType*)int_buffer_) + 4 * batch_size_after_partition_ +
batch_size_before_partition_ + 2;
} else {
return nullptr;
}
Expand All @@ -89,30 +86,33 @@ class BatchDecodeHandler {
cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, RotaryMode rotary_mode) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, DTypeIn, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(
tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
new_batch_size_ = new_batch_size;
batch_size_after_partition_ = new_batch_size;
if (tmp_size > 0) {
FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, sizeof(float) * tmp_size, stream_));
FLASHINFER_CUDA_CALL(
cudaMallocAsync(&int_buffer_, sizeof(IdType) * (6 * new_batch_size + 2), stream_));
FLASHINFER_CUDA_CALL(SplitPagedCacheKVComputeAuxiliaryInfo(
FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
FLASHINFER_CUDA_CALL(cudaMallocAsync(
&int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2),
stream_));
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetCooperativeIndPtr<IdType>(),
GetBatchIndexMap<IdType>(), GetChunkStartPos<IdType>(),
GetSeqLengthsBeforeSplit<IdType>(), stream_));
GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
GetBatchIdxMap<IdType>(), GetChunkStartPos<IdType>(),
GetSeqLengthsBeforePartition<IdType>(), stream_));
}
forward_started_ = true;
return cudaSuccess;
}

cudaError_t EndForward() {
forward_started_ = false;
new_batch_size_ = 0;
batch_size_before_partition_ = 0;
batch_size_after_partition_ = 0;
if (float_buffer_ != nullptr) {
FLASHINFER_CUDA_CALL(cudaFreeAsync(float_buffer_, stream_));
float_buffer_ = nullptr;
Expand All @@ -126,23 +126,26 @@ class BatchDecodeHandler {

bool IsForwardStarted() const { return forward_started_; }

uint32_t GetNewBatchSize() const { return new_batch_size_; }
uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }

uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }

cudaStream_t GetCUDAStream() const { return stream_; }

void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }

BatchDecodeHandler()
: new_batch_size_(0U),
: batch_size_after_partition_(0U),
float_buffer_(nullptr),
int_buffer_(nullptr),
forward_started_(false),
stream_(nullptr) {}
~BatchDecodeHandler() { EndForward(); }

private:
uint32_t new_batch_size_;
float* float_buffer_;
uint32_t batch_size_before_partition_;
uint32_t batch_size_after_partition_;
void* float_buffer_;
void* int_buffer_;
bool forward_started_;
cudaStream_t stream_;
Expand Down Expand Up @@ -253,14 +256,19 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
float rope_scale = 1.f, float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
paged_kv_t<page_storage, DTypeIn, IdType> new_paged_kv = paged_kv;
float* tmp = handler->GetTempFloatBuffer();
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
if (handler->IsForwardStarted()) {
if (tmp != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetNewBatchSize();
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
new_paged_kv.cooperative_aux_info = handler->GetCooperativeAuxInfo<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}
} else {
std::cerr << "Please call BatchDecodeHandler's BeginForward() before calling "
Expand All @@ -269,7 +277,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
abort();
}
return BatchDecodeWithPagedKVCache<page_storage, DTypeIn, DTypeOut, IdType>(
q, new_paged_kv, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, rope_theta, stream);
q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
rope_theta, stream);
}

template <PageStorage page_storage, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, RotaryMode ROTARY_MODE,
Expand Down
116 changes: 35 additions & 81 deletions include/flashinfer/page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ enum class PageStorage {
kPointer = 1U, // Store the pointers to each active page.
};

/*!
* \brief The auxiliary information about kv sequence partitioning
*/
template <typename IdType>
struct kv_partition_info_t {
uint32_t batch_size_before_partition;
IdType* chunk_indptr;
IdType* batch_idx_map;
IdType* chunk_start_pos;
IdType* seq_lens_before_partition;

__host__ __device__ __forceinline__ kv_partition_info_t(uint32_t batch_size_before_partition,
IdType* chunk_indptr,
IdType* batch_idx_map,
IdType* chunk_start_pos,
IdType* seq_lens_before_partition)
: batch_size_before_partition(batch_size_before_partition),
chunk_indptr(chunk_indptr),
batch_idx_map(batch_idx_map),
chunk_start_pos(chunk_start_pos),
seq_lens_before_partition(seq_lens_before_partition) {}

__host__ __device__ __forceinline__ kv_partition_info_t()
: batch_size_before_partition(0),
chunk_indptr(nullptr),
batch_idx_map(nullptr),
chunk_start_pos(nullptr),
seq_lens_before_partition(nullptr) {}
};

/*!
* \brief Paged key-value cache
* \tparam page_storage Whether to store indices or pointers of each active page
Expand Down Expand Up @@ -56,24 +86,6 @@ struct paged_kv_t {
// [batch_size] The offset of the last page for each request in the batch
IdType* last_page_len;

/* ------------ Auxliary Information Used in Cooperative Kernels ------------ */
IdType* cooperative_aux_info;
__host__ __device__ __forceinline__ IdType* cooperative_indptr() const {
return cooperative_aux_info;
}

__host__ __device__ __forceinline__ IdType* batch_idx_map() const {
return cooperative_aux_info + batch_size + 1;
}

__host__ __device__ __forceinline__ IdType* chunk_start() const {
return cooperative_aux_info + 2 * batch_size + 1;
}

__host__ __device__ __forceinline__ IdType* seq_lens_before_split() const {
return cooperative_aux_info + 3 * batch_size + 1;
}

/*!
* \brief Construct an empty paged key-value cache
*/
Expand All @@ -86,11 +98,10 @@ struct paged_kv_t {
indices(nullptr),
ptrs(nullptr),
indptr(nullptr),
last_page_len(nullptr),
cooperative_aux_info(nullptr) {}
last_page_len(nullptr) {}

/*!
* \brief Construct a paged key-value cache for non-cooperative kernels
* \brief Construct a paged key-value cache
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
Expand All @@ -112,11 +123,10 @@ struct paged_kv_t {
data(data),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
cooperative_aux_info(nullptr) {}
last_page_len(last_page_len) {}

/*!
* \brief Construct a paged key-value cache for non-cooperative kernels
* \brief Construct a paged key-value cache
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
Expand All @@ -135,63 +145,7 @@ struct paged_kv_t {
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
indptr(indptr),
last_page_len(last_page_len),
cooperative_aux_info(nullptr) {}

/*!
* \brief Construct a paged key-value cache with auxiliary information for cooperative kernels
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param data The flattened key-value cache
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param cooperative_aux_info The auxiliary information used in cooperative kernels
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType* data, IdType* indices, IdType* indptr,
IdType* last_page_len,
IdType* cooperative_aux_info)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
data(data),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
cooperative_aux_info(cooperative_aux_info) {}

/*!
* \brief Construct a paged key-value cache with auxiliary information for cooperative kernels
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param ptrs The array of pointers to each active page
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param cooperative_aux_info The auxiliary information used in cooperative kernels
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
uint32_t head_dim, uint32_t batch_size,
DType** ptrs, IdType* indptr,
IdType* last_page_len,
IdType* cooperative_aux_info)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
indptr(indptr),
last_page_len(last_page_len),
cooperative_aux_info(cooperative_aux_info) {}
indptr(indptr) {}

/*!
* \brief Compute the offset of k element in the allocated buffer.
Expand Down
Loading