
Commit 29733e7

[Refactor] Use two kernels instead of CUDA cooperative kernel for batch/single decode (#72)
In our initial design we used CUDA cooperative kernels and the grid synchronization feature for cross-threadblock reduction. Although this is slightly faster for multi-head attention without grouping (i.e., no GQA), we found two issues with that implementation:
1. The kernel scheduling of the cross-threadblock merge (the code after `grid.sync`) is sub-optimal; the merge time becomes the bottleneck when the number of chunks to merge is large (e.g. with GQA).
2. Not all hardware has a grid synchronization feature (only NVIDIA and AMD GPUs do, as far as I know), which makes the FlashInfer implementation hard to generalize to other backends such as Metal.
In this PR we stop using CUDA cooperative kernels and instead use a combination of two kernels: one for decode, another for merge. This costs a little performance for some shapes, but I believe it is beneficial for longer-term development and maintenance.
1 parent ebd067a commit 29733e7
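For readers unfamiliar with the pattern, the merge step works like split-KV ("flash-decoding" style) attention: the decode kernel writes, for each KV chunk, a partial output vector together with its log-sum-exp, and a second kernel combines them. The sketch below is a minimal illustration of that combine step, not the FlashInfer kernel; the name MergePartialStates, the [num_pairs, num_chunks, head_dim] layout, a uniform chunk count, and head_dim <= blockDim.x are all assumptions made for the example.

#include <cstddef>
#include <cstdint>
#include <math.h>

// Each (request, head) pair owns `num_chunks` partial results produced by the
// decode kernel: a partial output vector and its log-sum-exp. This kernel
// combines them with a numerically stable softmax merge.
// blockIdx.x indexes the (request, head) pair, threadIdx.x indexes head_dim.
__global__ void MergePartialStates(const float* __restrict__ v_partial,    // [num_pairs, num_chunks, head_dim]
                                   const float* __restrict__ lse_partial,  // [num_pairs, num_chunks]
                                   float* __restrict__ v_merged,           // [num_pairs, head_dim]
                                   uint32_t num_chunks, uint32_t head_dim) {
  const uint32_t pair = blockIdx.x, d = threadIdx.x;
  if (d >= head_dim) return;
  const float* lse = lse_partial + pair * num_chunks;
  const float* v = v_partial + static_cast<size_t>(pair) * num_chunks * head_dim;
  // Running maximum over the chunk log-sum-exps, for numerical stability.
  float m = -INFINITY;
  for (uint32_t i = 0; i < num_chunks; ++i) m = fmaxf(m, lse[i]);
  float denom = 0.f, acc = 0.f;
  for (uint32_t i = 0; i < num_chunks; ++i) {
    float w = __expf(lse[i] - m);  // relative weight of chunk i
    denom += w;
    acc += w * v[i * head_dim + d];
  }
  v_merged[pair * head_dim + d] = acc / denom;
}

Launching one thread block per (request, head) pair keeps the merge independent of any grid-wide synchronization, which is what removes the dependence on cooperative launch.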

File tree

9 files changed, +239 −327 lines changed

include/flashinfer/decode.cuh

Lines changed: 113 additions & 167 deletions
Large diffs are not rendered by default.

include/flashinfer/handler.cuh

Lines changed: 44 additions & 35 deletions
@@ -30,56 +30,53 @@ namespace flashinfer {
 
 class BatchDecodeHandler {
  public:
-  float* GetTempFloatBuffer() const { return float_buffer_; }
+  template <typename DType>
+  DType* GetTempFloatBuffer() const {
+    return (DType*)float_buffer_;
+  }
   template <typename IdType>
   IdType* GetNewIndPtr() const {
     return (IdType*)int_buffer_;
   }
   template <typename IdType>
   IdType* GetNewLastPageLen() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + new_batch_size_ + 1;
-    } else {
-      return nullptr;
-    }
-  }
-  // cooperative_aux_info starts with cooperative_indptr
-  template <typename IdType>
-  IdType* GetCooperativeAuxInfo() const {
-    if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
+      return ((IdType*)int_buffer_) + batch_size_after_partition_ + 1;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetCooperativeIndPtr() const {
+  IdType* GetChunkIndPtr() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
+      return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + 1;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetBatchIndexMap() const {
+  IdType* GetBatchIdxMap() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 3 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
   IdType* GetChunkStartPos() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 4 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 3 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetSeqLengthsBeforeSplit() const {
+  IdType* GetSeqLengthsBeforePartition() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 5 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 4 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
@@ -89,30 +86,33 @@ class BatchDecodeHandler {
   cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size,
                            uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                            uint32_t page_size, RotaryMode rotary_mode) {
+    batch_size_before_partition_ = batch_size;
     uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
     auto work_estimation_func =
         BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, DTypeIn, DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(
         tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
         num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
-    new_batch_size_ = new_batch_size;
+    batch_size_after_partition_ = new_batch_size;
     if (tmp_size > 0) {
-      FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, sizeof(float) * tmp_size, stream_));
-      FLASHINFER_CUDA_CALL(
-          cudaMallocAsync(&int_buffer_, sizeof(IdType) * (6 * new_batch_size + 2), stream_));
-      FLASHINFER_CUDA_CALL(SplitPagedCacheKVComputeAuxiliaryInfo(
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(
+          &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2),
+          stream_));
+      FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
          max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
-          GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetCooperativeIndPtr<IdType>(),
-          GetBatchIndexMap<IdType>(), GetChunkStartPos<IdType>(),
-          GetSeqLengthsBeforeSplit<IdType>(), stream_));
+          GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
+          GetBatchIdxMap<IdType>(), GetChunkStartPos<IdType>(),
+          GetSeqLengthsBeforePartition<IdType>(), stream_));
     }
     forward_started_ = true;
     return cudaSuccess;
   }
 
   cudaError_t EndForward() {
     forward_started_ = false;
-    new_batch_size_ = 0;
+    batch_size_before_partition_ = 0;
+    batch_size_after_partition_ = 0;
     if (float_buffer_ != nullptr) {
       FLASHINFER_CUDA_CALL(cudaFreeAsync(float_buffer_, stream_));
       float_buffer_ = nullptr;
@@ -126,23 +126,26 @@ class BatchDecodeHandler {
 
   bool IsForwardStarted() const { return forward_started_; }
 
-  uint32_t GetNewBatchSize() const { return new_batch_size_; }
+  uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }
+
+  uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
 
   cudaStream_t GetCUDAStream() const { return stream_; }
 
   void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }
 
   BatchDecodeHandler()
-      : new_batch_size_(0U),
+      : batch_size_after_partition_(0U),
         float_buffer_(nullptr),
         int_buffer_(nullptr),
         forward_started_(false),
         stream_(nullptr) {}
   ~BatchDecodeHandler() { EndForward(); }
 
  private:
-  uint32_t new_batch_size_;
-  float* float_buffer_;
+  uint32_t batch_size_before_partition_;
+  uint32_t batch_size_after_partition_;
+  void* float_buffer_;
   void* int_buffer_;
   bool forward_started_;
   cudaStream_t stream_;
@@ -253,14 +256,19 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
                                                float rope_scale = 1.f, float rope_theta = 1e4,
                                                cudaStream_t stream = nullptr) {
   paged_kv_t<page_storage, DTypeIn, IdType> new_paged_kv = paged_kv;
-  float* tmp = handler->GetTempFloatBuffer();
+  kv_partition_info_t<IdType> kv_partition_info;
+  DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
   if (handler->IsForwardStarted()) {
     if (tmp != nullptr) {
       // create auxiliary information for cooperative kernels
-      new_paged_kv.batch_size = handler->GetNewBatchSize();
+      new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
       new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
       new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-      new_paged_kv.cooperative_aux_info = handler->GetCooperativeAuxInfo<IdType>();
+      kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
+      kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+      kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+      kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+      kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
     }
   } else {
     std::cerr << "Please call BatchDecodeHandler's BeginForward() before calling "
@@ -269,7 +277,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
     abort();
   }
   return BatchDecodeWithPagedKVCache<page_storage, DTypeIn, DTypeOut, IdType>(
-      q, new_paged_kv, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, rope_theta, stream);
+      q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
+      rope_theta, stream);
 }
 
 template <PageStorage page_storage, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, RotaryMode ROTARY_MODE,
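Reading the getters and the cudaMallocAsync size in this diff together, int_buffer_ appears to be one contiguous IdType array carved into six regions. The helpers below only restate those offsets for illustration (nb = batch_size_after_partition_, ob = batch_size_before_partition_); they are not part of the library.

#include <cstdint>

// Offsets (in IdType elements) into int_buffer_, mirroring the getters above.
constexpr uint32_t NewIndPtrOffset(uint32_t /*nb*/, uint32_t /*ob*/) { return 0; }
constexpr uint32_t NewLastPageLenOffset(uint32_t nb, uint32_t /*ob*/) { return nb + 1; }
constexpr uint32_t ChunkIndPtrOffset(uint32_t nb, uint32_t /*ob*/) { return 2 * nb + 1; }
constexpr uint32_t BatchIdxMapOffset(uint32_t nb, uint32_t ob) { return 2 * nb + ob + 2; }
constexpr uint32_t ChunkStartPosOffset(uint32_t nb, uint32_t ob) { return 3 * nb + ob + 2; }
constexpr uint32_t SeqLensBeforePartitionOffset(uint32_t nb, uint32_t ob) { return 4 * nb + ob + 2; }
// Total element count, matching the allocation
// sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2).
constexpr uint32_t IntBufferNumElements(uint32_t nb, uint32_t ob) { return 5 * nb + ob + 2; }

// The last region (seq_lens_before_partition) holds nb entries, so the six
// regions tile the buffer exactly.
static_assert(SeqLensBeforePartitionOffset(7, 3) + 7 == IntBufferNumElements(7, 3),
              "partition-info regions tile int_buffer_");

This also suggests why the allocation changed from 6 * new_batch_size + 2 to 5 * new_batch_size + batch_size_before_partition_ + 2: chunk_indptr is now indexed by the original (pre-partition) batch size rather than the partitioned one.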

include/flashinfer/page.cuh

Lines changed: 35 additions & 81 deletions
@@ -29,6 +29,36 @@ enum class PageStorage {
   kPointer = 1U,  // Store the pointers to each active page.
 };
 
+/*!
+ * \brief The auxiliary information about kv sequence partitioning
+ */
+template <typename IdType>
+struct kv_partition_info_t {
+  uint32_t batch_size_before_partition;
+  IdType* chunk_indptr;
+  IdType* batch_idx_map;
+  IdType* chunk_start_pos;
+  IdType* seq_lens_before_partition;
+
+  __host__ __device__ __forceinline__ kv_partition_info_t(uint32_t batch_size_before_partition,
+                                                          IdType* chunk_indptr,
+                                                          IdType* batch_idx_map,
+                                                          IdType* chunk_start_pos,
+                                                          IdType* seq_lens_before_partition)
+      : batch_size_before_partition(batch_size_before_partition),
+        chunk_indptr(chunk_indptr),
+        batch_idx_map(batch_idx_map),
+        chunk_start_pos(chunk_start_pos),
+        seq_lens_before_partition(seq_lens_before_partition) {}
+
+  __host__ __device__ __forceinline__ kv_partition_info_t()
+      : batch_size_before_partition(0),
+        chunk_indptr(nullptr),
+        batch_idx_map(nullptr),
+        chunk_start_pos(nullptr),
+        seq_lens_before_partition(nullptr) {}
+};
+
 /*!
  * \brief Paged key-value cache
  * \tparam page_storage Whether to store indices or pointers of each active page
@@ -56,24 +86,6 @@ struct paged_kv_t {
   // [batch_size] The offset of the last page for each request in the batch
   IdType* last_page_len;
 
-  /* ------------ Auxliary Information Used in Cooperative Kernels ------------ */
-  IdType* cooperative_aux_info;
-  __host__ __device__ __forceinline__ IdType* cooperative_indptr() const {
-    return cooperative_aux_info;
-  }
-
-  __host__ __device__ __forceinline__ IdType* batch_idx_map() const {
-    return cooperative_aux_info + batch_size + 1;
-  }
-
-  __host__ __device__ __forceinline__ IdType* chunk_start() const {
-    return cooperative_aux_info + 2 * batch_size + 1;
-  }
-
-  __host__ __device__ __forceinline__ IdType* seq_lens_before_split() const {
-    return cooperative_aux_info + 3 * batch_size + 1;
-  }
-
   /*!
    * \brief Construct an empty paged key-value cache
    */
@@ -86,11 +98,10 @@ struct paged_kv_t {
         indices(nullptr),
         ptrs(nullptr),
         indptr(nullptr),
-        last_page_len(nullptr),
-        cooperative_aux_info(nullptr) {}
+        last_page_len(nullptr) {}
 
   /*!
-   * \brief Construct a paged key-value cache for non-cooperative kernels
+   * \brief Construct a paged key-value cache
    * \param num_heads The number of heads
    * \param page_size The size of each page
    * \param head_dim The dimension of each head
@@ -112,11 +123,10 @@ struct paged_kv_t {
         data(data),
         indices(indices),
         indptr(indptr),
-        last_page_len(last_page_len),
-        cooperative_aux_info(nullptr) {}
+        last_page_len(last_page_len) {}
 
   /*!
-   * \brief Construct a paged key-value cache for non-cooperative kernels
+   * \brief Construct a paged key-value cache
    * \param num_heads The number of heads
    * \param page_size The size of each page
    * \param head_dim The dimension of each head
@@ -135,63 +145,7 @@ struct paged_kv_t {
         head_dim(head_dim),
         batch_size(batch_size),
         ptrs(ptrs),
-        indptr(indptr),
-        last_page_len(last_page_len),
-        cooperative_aux_info(nullptr) {}
-
-  /*!
-   * \brief Construct a paged key-value cache with auxiliary information for cooperative kernels
-   * \param num_heads The number of heads
-   * \param page_size The size of each page
-   * \param head_dim The dimension of each head
-   * \param batch_size The batch size
-   * \param data The flattened key-value cache
-   * \param indices The page indices array
-   * \param indptr The page indptr array
-   * \param last_page_len The offset of the last page for each request in the batch
-   * \param cooperative_aux_info The auxiliary information used in cooperative kernels
-   * \note This constructor should only be used when page_storage == kIndices
-   */
-  __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
-                                                 uint32_t head_dim, uint32_t batch_size,
-                                                 DType* data, IdType* indices, IdType* indptr,
-                                                 IdType* last_page_len,
-                                                 IdType* cooperative_aux_info)
-      : num_heads(num_heads),
-        page_size(page_size),
-        head_dim(head_dim),
-        batch_size(batch_size),
-        data(data),
-        indices(indices),
-        indptr(indptr),
-        last_page_len(last_page_len),
-        cooperative_aux_info(cooperative_aux_info) {}
-
-  /*!
-   * \brief Construct a paged key-value cache with auxiliary information for cooperative kernels
-   * \param num_heads The number of heads
-   * \param page_size The size of each page
-   * \param head_dim The dimension of each head
-   * \param batch_size The batch size
-   * \param ptrs The array of pointers to each active page
-   * \param indptr The page indptr array
-   * \param last_page_len The offset of the last page for each request in the batch
-   * \param cooperative_aux_info The auxiliary information used in cooperative kernels
-   * \note This constructor should only be used when page_storage == kIndices
-   */
-  __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
-                                                 uint32_t head_dim, uint32_t batch_size,
-                                                 DType** ptrs, IdType* indptr,
-                                                 IdType* last_page_len,
-                                                 IdType* cooperative_aux_info)
-      : num_heads(num_heads),
-        page_size(page_size),
-        head_dim(head_dim),
-        batch_size(batch_size),
-        ptrs(ptrs),
-        indptr(indptr),
-        last_page_len(last_page_len),
-        cooperative_aux_info(cooperative_aux_info) {}
+        indptr(indptr) {}
 
   /*!
    * \brief Compute the offset of k element in the allocated buffer.
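The new kv_partition_info_t struct is easiest to read through a toy construction. The sketch below is a host-side illustration under assumed semantics, not the library's PartitionPagedKVCacheComputeAuxiliaryInfo routine; the page-level granularity and the max_pages_per_chunk parameter are assumptions made for the example.

#include <cstdint>
#include <vector>

// Toy, host-side analogue of the partition bookkeeping (assumed semantics).
struct ToyPartitionInfo {
  std::vector<int32_t> chunk_indptr;               // [old_batch_size + 1]: chunk range per original request
  std::vector<int32_t> batch_idx_map;              // [new_batch_size]: chunk -> original request index
  std::vector<int32_t> chunk_start_pos;            // [new_batch_size]: first page covered by the chunk
  std::vector<int32_t> seq_lens_before_partition;  // [new_batch_size]: original length, replicated per chunk
};

ToyPartitionInfo PartitionByPages(const std::vector<int32_t>& page_indptr,
                                  int32_t max_pages_per_chunk) {
  ToyPartitionInfo info;
  info.chunk_indptr.push_back(0);
  const int32_t old_batch_size = static_cast<int32_t>(page_indptr.size()) - 1;
  for (int32_t i = 0; i < old_batch_size; ++i) {
    const int32_t num_pages = page_indptr[i + 1] - page_indptr[i];
    const int32_t num_chunks = (num_pages + max_pages_per_chunk - 1) / max_pages_per_chunk;
    for (int32_t c = 0; c < num_chunks; ++c) {
      info.batch_idx_map.push_back(i);                        // this chunk belongs to request i
      info.chunk_start_pos.push_back(c * max_pages_per_chunk);
      info.seq_lens_before_partition.push_back(num_pages);    // full pre-partition length
    }
    info.chunk_indptr.push_back(info.chunk_indptr.back() + num_chunks);
  }
  return info;
}

With this mapping the merge step can walk chunk_indptr[i]..chunk_indptr[i+1] to find every partial result belonging to original request i, while batch_idx_map lets a chunk recover its owner without a search.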
