@@ -30,56 +30,53 @@ namespace flashinfer {
 
 class BatchDecodeHandler {
  public:
-  float* GetTempFloatBuffer() const { return float_buffer_; }
+  template <typename DType>
+  DType* GetTempFloatBuffer() const {
+    return (DType*)float_buffer_;
+  }
   template <typename IdType>
   IdType* GetNewIndPtr() const {
     return (IdType*)int_buffer_;
   }
   template <typename IdType>
   IdType* GetNewLastPageLen() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + new_batch_size_ + 1;
-    } else {
-      return nullptr;
-    }
-  }
-  // cooperative_aux_info starts with cooperative_indptr
-  template <typename IdType>
-  IdType* GetCooperativeAuxInfo() const {
-    if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
+      return ((IdType*)int_buffer_) + batch_size_after_partition_ + 1;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetCooperativeIndPtr() const {
+  IdType* GetChunkIndPtr() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 2 * new_batch_size_ + 1;
+      return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ + 1;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetBatchIndexMap() const {
+  IdType* GetBatchIdxMap() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 3 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 2 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
   IdType* GetChunkStartPos() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 4 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 3 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
   }
   template <typename IdType>
-  IdType* GetSeqLengthsBeforeSplit() const {
+  IdType* GetSeqLengthsBeforePartition() const {
     if (int_buffer_ != nullptr) {
-      return ((IdType*)int_buffer_) + 5 * new_batch_size_ + 2;
+      return ((IdType*)int_buffer_) + 4 * batch_size_after_partition_ +
+             batch_size_before_partition_ + 2;
     } else {
       return nullptr;
     }
   }
@@ -89,30 +86,33 @@ class BatchDecodeHandler {
   cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size,
                            uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                            uint32_t page_size, RotaryMode rotary_mode) {
+    batch_size_before_partition_ = batch_size;
     uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
     auto work_estimation_func =
         BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, DTypeIn, DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(
         tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
         num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
-    new_batch_size_ = new_batch_size;
+    batch_size_after_partition_ = new_batch_size;
     if (tmp_size > 0) {
-      FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, sizeof(float) * tmp_size, stream_));
-      FLASHINFER_CUDA_CALL(
-          cudaMallocAsync(&int_buffer_, sizeof(IdType) * (6 * new_batch_size + 2), stream_));
-      FLASHINFER_CUDA_CALL(SplitPagedCacheKVComputeAuxiliaryInfo(
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(
+          &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2),
+          stream_));
+      FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
           max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
-          GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetCooperativeIndPtr<IdType>(),
-          GetBatchIndexMap<IdType>(), GetChunkStartPos<IdType>(),
-          GetSeqLengthsBeforeSplit<IdType>(), stream_));
+          GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
+          GetBatchIdxMap<IdType>(), GetChunkStartPos<IdType>(),
+          GetSeqLengthsBeforePartition<IdType>(), stream_));
     }
     forward_started_ = true;
     return cudaSuccess;
   }
 
   cudaError_t EndForward() {
     forward_started_ = false;
-    new_batch_size_ = 0;
+    batch_size_before_partition_ = 0;
+    batch_size_after_partition_ = 0;
     if (float_buffer_ != nullptr) {
       FLASHINFER_CUDA_CALL(cudaFreeAsync(float_buffer_, stream_));
       float_buffer_ = nullptr;
@@ -126,23 +126,26 @@ class BatchDecodeHandler {
 
   bool IsForwardStarted() const { return forward_started_; }
 
-  uint32_t GetNewBatchSize() const { return new_batch_size_; }
+  uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }
+
+  uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
 
   cudaStream_t GetCUDAStream() const { return stream_; }
 
   void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }
 
   BatchDecodeHandler()
-      : new_batch_size_(0U),
+      : batch_size_after_partition_(0U),
         float_buffer_(nullptr),
         int_buffer_(nullptr),
         forward_started_(false),
         stream_(nullptr) {}
   ~BatchDecodeHandler() { EndForward(); }
 
  private:
-  uint32_t new_batch_size_;
-  float* float_buffer_;
+  uint32_t batch_size_before_partition_;
+  uint32_t batch_size_after_partition_;
+  void* float_buffer_;
   void* int_buffer_;
   bool forward_started_;
   cudaStream_t stream_;
@@ -253,14 +256,19 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
                                                float rope_scale = 1.f, float rope_theta = 1e4,
                                                cudaStream_t stream = nullptr) {
   paged_kv_t<page_storage, DTypeIn, IdType> new_paged_kv = paged_kv;
-  float* tmp = handler->GetTempFloatBuffer();
+  kv_partition_info_t<IdType> kv_partition_info;
+  DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
   if (handler->IsForwardStarted()) {
     if (tmp != nullptr) {
       // create auxiliary information for cooperative kernels
-      new_paged_kv.batch_size = handler->GetNewBatchSize();
+      new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
       new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
       new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-      new_paged_kv.cooperative_aux_info = handler->GetCooperativeAuxInfo<IdType>();
+      kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
+      kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+      kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+      kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+      kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
     }
   } else {
     std::cerr << "Please call BatchDecodeHandler's BeginForward() before calling "
@@ -269,7 +277,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
     abort();
   }
   return BatchDecodeWithPagedKVCache<page_storage, DTypeIn, DTypeOut, IdType>(
-      q, new_paged_kv, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale, rope_theta, stream);
+      q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
+      rope_theta, stream);
 }
 
 template <PageStorage page_storage, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, RotaryMode ROTARY_MODE,
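
Note on the packed int_buffer_ layout: the getter offsets in this change define six contiguous IdType arrays inside a single allocation. The standalone sketch below (illustrative names only, not flashinfer API; array lengths are inferred from the differences between consecutive getter offsets) spells the layout out and checks that it sums to the `5 * new_batch_size + batch_size_before_partition_ + 2` elements allocated in BeginForward.

#include <cassert>
#include <cstdint>

// Sketch of the packed int_buffer_ layout implied by the getters in this diff.
// bs_before / bs_after are the batch sizes before / after KV partitioning.
void CheckIntBufferLayout(uint32_t bs_before, uint32_t bs_after) {
  uint32_t new_indptr = 0;                                   // GetNewIndPtr
  uint32_t new_last_page_len = new_indptr + (bs_after + 1);  // GetNewLastPageLen
  uint32_t chunk_indptr = new_last_page_len + bs_after;      // GetChunkIndPtr
  uint32_t batch_idx_map = chunk_indptr + (bs_before + 1);   // GetBatchIdxMap
  uint32_t chunk_start_pos = batch_idx_map + bs_after;       // GetChunkStartPos
  uint32_t seq_lens_before = chunk_start_pos + bs_after;     // GetSeqLengthsBeforePartition
  uint32_t total = seq_lens_before + bs_after;               // one past the last array
  // The offsets reproduce the pointer arithmetic in the getters...
  assert(chunk_indptr == 2 * bs_after + 1);
  assert(batch_idx_map == 2 * bs_after + bs_before + 2);
  assert(chunk_start_pos == 3 * bs_after + bs_before + 2);
  assert(seq_lens_before == 4 * bs_after + bs_before + 2);
  // ...and the total matches the cudaMallocAsync size in BeginForward.
  assert(total == 5 * bs_after + bs_before + 2);
}

The two indptr-style arrays explain the mixed batch sizes: new_indptr and new_last_page_len are sized by the partitioned batch, while chunk_indptr has one range per original request, which is why bs_before appears in the allocation.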
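For orientation, here is a hedged sketch of the revised call sequence around one decode step. The explicit template arguments, `PageStorage::kIndices`, and the helper name `DecodeOneStep` are assumptions for illustration; buffer allocation, `paged_kv_t` population, and the exact residency of `indptr` / `last_page_len` follow the library's existing conventions and are elided.

#include <cuda_fp16.h>

// Hypothetical helper showing the handler lifecycle around one decode step;
// a sketch against the signatures visible in this diff, not a tested program.
template <typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t DecodeOneStep(
    flashinfer::BatchDecodeHandler& handler,
    flashinfer::paged_kv_t<flashinfer::PageStorage::kIndices, DTypeIn, IdType> paged_kv,
    DTypeIn* q, DTypeOut* o, IdType* indptr, IdType* last_page_len, uint32_t batch_size,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) {
  using flashinfer::PageStorage;
  using flashinfer::RotaryMode;
  // Plan: estimate temp-buffer needs, partition long requests into chunks,
  // and materialize the auxiliary arrays on the handler's stream.
  cudaError_t status = handler.BeginForward<PageStorage::kIndices, DTypeIn, DTypeOut, IdType>(
      indptr, last_page_len, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
      RotaryMode::kNone);
  if (status != cudaSuccess) return status;
  // Run: the wrapper swaps in the partitioned indptr / last_page_len and
  // forwards the kv_partition_info_t to BatchDecodeWithPagedKVCache.
  status = flashinfer::BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, DTypeIn,
                                                          DTypeOut, IdType>(
      &handler, q, paged_kv, o, /*lse=*/nullptr, num_qo_heads);
  if (status != cudaSuccess) return status;
  // Cleanup: free the temporary buffers and reset both batch-size fields.
  return handler.EndForward();
}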