@@ -137,8 +137,8 @@ __device__ __forceinline__ void threadblock_sync_state(state_t<vec_size>& st, DT
  * \param vec_size The vector size used in the kernel.
  * \tparam DTypeIn The data type of v.
  * \tparam DTypeOut The data type of v_merged.
- * \param v The partial v of index sets. (num_index_sets, n, h, d)
- * \param s The logsumexp value of index sets. (num_index_sets, n, h)
+ * \param v The partial v of index sets. (n, num_index_sets, h, d)
+ * \param s The logsumexp value of index sets. (n, num_index_sets, h)
  * \param v_merged The merged v of index sets union. (n, h, d)
  * \param s_merged The merged logsumexp value of index sets union. (n, h)
  * \param num_heads The number of heads of v.
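Note on the layout change above: moving num_index_sets from the outermost axis to the second axis makes all partial states of one position contiguous in memory. A minimal host-side sketch of the two flattened offsets (helper names are illustrative, not part of the diff):

    #include <cstdint>

    // Old layout (num_index_sets, n, h, d): partial states of one position are strided.
    uint32_t offset_old(uint32_t iter, uint32_t pos, uint32_t head, uint32_t d,
                        uint32_t seq_len, uint32_t num_heads, uint32_t head_dim) {
      return ((iter * seq_len + pos) * num_heads + head) * head_dim + d;
    }

    // New layout (n, num_index_sets, h, d): partial states of one position are contiguous.
    uint32_t offset_new(uint32_t pos, uint32_t iter, uint32_t head, uint32_t d,
                        uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) {
      return ((pos * num_index_sets + iter) * num_heads + head) * head_dim + d;
    }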
@@ -150,7 +150,6 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
                                    DTypeOut* __restrict__ v_merged, float* __restrict__ s_merged,
                                    uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  uint32_t seq_len = gridDim.x;
   uint32_t pos = blockIdx.x;
   uint32_t head_idx = ty;
   state_t<vec_size> st;
@@ -159,9 +158,10 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
   v_merged_vec.fill(0.f);
 #pragma unroll 2
   for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
-    float s = S[(iter * seq_len + pos) * num_heads + head_idx];
+    float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
     vec_t<float, vec_size> v;
-    v.cast_load(V + ((iter * seq_len + pos) * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
+                tx * vec_size);
     st.merge(v, s, 1);
   }
@@ -175,14 +175,14 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
 /*!
  * \brief The CUDA kernel that merges self-attention states of a list of index sets,
  *   accelerated for a large number of index sets.
- * \param vec_size The vector size used in the kernel.
- * \param bdx The blockDim.x used in the kernel.
- * \param bdy The blockDim.y used in the kernel.
- * \param num_smem_stages The number of stages of shared memory used in the kernel.
+ * \tparam vec_size The vector size used in the kernel.
+ * \tparam bdx The blockDim.x used in the kernel.
+ * \tparam bdy The blockDim.y used in the kernel.
+ * \tparam num_smem_stages The number of stages of shared memory used in the kernel.
  * \tparam DTypeIn The data type of v.
  * \tparam DTypeOut The data type of v_merged.
- * \param v The partial v of index sets. (num_index_sets, n, h, d)
- * \param s The logsumexp value of index sets. (num_index_sets, n, h)
+ * \param V The partial v of index sets. (n, num_index_sets, h, d)
+ * \param S The logsumexp value of index sets. (n, num_index_sets, h)
  * \param v_merged The merged v of index sets union. (n, h, d)
  * \param s_merged The merged logsumexp value of index sets union. (n, h)
  * \param num_heads The number of heads of v.
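These template parameters also fix the kernel's dynamic shared memory footprint (the v_smem staging buffer plus the s_smem scratch). A worked example under assumed values that match the launcher later in this diff (DTypeIn = half, head_dim = 128):

    #include <cstdint>

    // vec_size = max(16 / sizeof(half), 128 / 32) = max(8, 4) = 8
    // bdx = head_dim / vec_size = 128 / 8 = 16, bdy = 128 / bdx = 8
    constexpr uint32_t num_smem_stages = 4, bdy = 8, head_dim = 128, num_threads = 128;
    constexpr uint32_t smem_size =
        num_smem_stages * bdy * head_dim * 2  // v_smem, sizeof(half) = 2
        + num_threads * 4;                    // s_smem, sizeof(float) = 4
    static_assert(smem_size == 8192 + 512, "4 * 8 * 128 * 2 B + 128 * 4 B = 8704 B");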
@@ -196,9 +196,8 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
                                                     float* __restrict__ s_merged,
                                                     uint32_t num_index_sets, uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  uint32_t seq_len = gridDim.y;
-  uint32_t pos = blockIdx.y;
-  uint32_t head_idx = blockIdx.x;
+  uint32_t pos = blockIdx.x;
+  uint32_t head_idx = blockIdx.y;
   state_t<vec_size> st;
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;
@@ -211,7 +210,8 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
   for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
     cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
         v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-        V + (((iter * bdy + ty) * seq_len + pos) * num_heads + head_idx) * head_dim + tx * vec_size,
+        V + ((pos * num_index_sets + (iter * bdy + ty)) * num_heads + head_idx) * head_dim +
+            tx * vec_size,
         (iter * bdy + ty) < num_index_sets);
     cp_async::commit_group();
   }
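The prologue above fills the first num_smem_stages pipeline slots; the main loop below consumes the oldest slot and immediately reuses it to prefetch num_smem_stages iterations ahead. A CPU-only model of the slot rotation (indexing logic only, no CUDA):

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr uint32_t num_smem_stages = 4;
      const uint32_t num_iters = 10;
      uint32_t slot_holds[num_smem_stages];

      for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
        slot_holds[iter] = iter;  // prologue: slot i is loaded with iteration i
      }
      for (uint32_t iter = 0; iter < num_iters; ++iter) {
        uint32_t slot = iter % num_smem_stages;
        std::printf("iter %u consumes slot %u (holding %u)\n", iter, slot, slot_holds[slot]);
        if (iter + num_smem_stages < num_iters) {
          slot_holds[slot] = iter + num_smem_stages;  // refill freed slot, N stages ahead
        }
      }
      return 0;
    }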
@@ -220,27 +220,111 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
     if (iter % bdx == 0) {
       s_smem[ty * bdx + tx] =
           iter * bdy + (ty * bdx + tx) < num_index_sets
-              ? S[((iter * bdy + ty * bdx + tx) * seq_len + pos) * num_heads + head_idx]
+              ? S[(pos * num_index_sets + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
               : 0.f;
       __syncthreads();
     }
     cp_async::wait_group<num_smem_stages - 1>();
     __syncthreads();
     vec_t<float, vec_size> v;
     v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+    if (iter * bdy + ty < num_index_sets) {
+      float s = s_smem[(iter % bdx) * bdy + ty];
+      st.merge(v, s, 1);
+    }
     __syncthreads();
     cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
         v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
         V +
-            ((((iter + num_smem_stages) * bdy + ty) * seq_len + pos) * num_heads + head_idx) *
+            ((pos * num_index_sets + ((iter + num_smem_stages) * bdy + ty)) * num_heads +
+             head_idx) *
                 head_dim +
             tx * vec_size,
         (iter + num_smem_stages) * bdy + ty < num_index_sets);
     cp_async::commit_group();
+  }
+  cp_async::wait_group<0>();
+  __syncthreads();
+
+  st.normalize();
+  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+  st.normalize();
+
+  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+  if (s_merged != nullptr) {
+    s_merged[pos * num_heads + head_idx] = st.get_lse();
+  }
+}
+
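In the epilogue above, each thread row ty has merged only index sets ty, ty + bdy, ty + 2*bdy, ..., so threadblock_sync_state combines the bdy per-row states into the block's final result. A host sketch of that row-strided work split (assumed bdy = 8):

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr uint32_t bdy = 8;
      const uint32_t num_index_sets = 19;
      for (uint32_t ty = 0; ty < bdy; ++ty) {
        std::printf("row %u merges index sets:", ty);
        // Mirrors the guard `iter * bdy + ty < num_index_sets` in the kernel loop.
        for (uint32_t iter = 0; iter * bdy + ty < num_index_sets; ++iter) {
          std::printf(" %u", iter * bdy + ty);
        }
        std::printf("\n");
      }
      return 0;
    }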
+/*!
+ * \brief The CUDA kernel to merge self-attention states of multiple index sets, where the
+ *   number of index sets at each position may vary.
+ * \tparam vec_size The vector size used in the kernel.
+ * \tparam bdx The blockDim.x used in the kernel.
+ * \tparam bdy The blockDim.y used in the kernel.
+ * \tparam num_smem_stages The number of stages of shared memory used in the kernel.
+ * \tparam DTypeIn The data type of v.
+ * \tparam DTypeOut The data type of v_merged.
+ * \param V The partial v of index sets. (nnz, h, d)
+ * \param S The logsumexp value of index sets. (nnz, h)
+ * \param indptr The start offsets of each position in the variable length array.
+ * \param v_merged The merged v of index sets union. (n, h, d)
+ * \param s_merged The merged logsumexp value of index sets union. (n, h)
+ * \param num_heads The number of heads of v.
+ * \param head_dim The dimension of each head.
+ * \note s are logsumexp values with base 2.
+ */
+template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
+          typename DTypeOut, typename IdType>
+__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S,
+                                                IdType* indptr, DTypeOut* __restrict__ v_merged,
+                                                float* __restrict__ s_merged, uint32_t num_heads) {
+  uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t pos = blockIdx.x;
+  uint32_t head_idx = blockIdx.y;
+  state_t<vec_size> st;
+  constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
+  constexpr uint32_t head_dim = vec_size * bdx;
+
+  extern __shared__ uint8_t smem[];
+  DTypeIn* v_smem = (DTypeIn*)smem;
+  float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
+  const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
+
+#pragma unroll
+  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+        (iter * bdy + ty) < num_index_sets);
+    cp_async::commit_group();
+  }
+#pragma unroll 4
+  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+    if (iter % bdx == 0) {
+      s_smem[ty * bdx + tx] =
+          iter * bdy + (ty * bdx + tx) < num_index_sets
+              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+              : 0.f;
+      __syncthreads();
+    }
+    cp_async::wait_group<num_smem_stages - 1>();
+    __syncthreads();
+    vec_t<float, vec_size> v;
+    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
     if (iter * bdy + ty < num_index_sets) {
       float s = s_smem[(iter % bdx) * bdy + ty];
       st.merge(v, s, 1);
     }
+    __syncthreads();
+    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+        V +
+            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                head_dim +
+            tx * vec_size,
+        (iter + num_smem_stages) * bdy + ty < num_index_sets);
+    cp_async::commit_group();
   }
   cp_async::wait_group<0>();
   __syncthreads();
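The new kernel above addresses V and S through indptr, CSR-style: position pos owns rows [indptr[pos], indptr[pos + 1]), and num_index_sets is the difference of adjacent offsets. A small worked example with made-up counts:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // 3 positions with 2, 5, and 1 index sets; V is (nnz, h, d) with nnz = 8.
      const uint32_t indptr[] = {0, 2, 7, 8};
      for (uint32_t pos = 0; pos < 3; ++pos) {
        std::printf("pos %u: rows [%u, %u), num_index_sets = %u\n", pos, indptr[pos],
                    indptr[pos + 1], indptr[pos + 1] - indptr[pos]);
      }
      return 0;
    }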
@@ -346,7 +430,7 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge
     if (num_index_sets > 2 * (128 / bdx)) {
       constexpr uint32_t num_threads = 128;
       constexpr uint32_t bdy = num_threads / bdx;
-      dim3 nblks(num_heads, seq_len);
+      dim3 nblks(seq_len, num_heads);
       dim3 nthrs(bdx, bdy);
       constexpr uint32_t num_smem_stages = 4;
       auto kernel = MergeStatesLargeNumIndexSetsKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
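The swapped dim3 order here is what keeps the launcher consistent with the kernel's new blockIdx mapping: dim3's first argument populates gridDim.x. A standalone sketch verifying the correspondence:

    #include <cstdio>

    __global__ void GridMappingDemo() {
      // With dim3 nblks(seq_len, num_heads): blockIdx.x enumerates positions,
      // blockIdx.y enumerates heads, matching the kernels in this diff.
      if (threadIdx.x == 0) printf("pos = %u, head_idx = %u\n", blockIdx.x, blockIdx.y);
    }

    int main() {
      dim3 nblks(/*seq_len=*/4, /*num_heads=*/2);
      GridMappingDemo<<<nblks, 32>>>();
      cudaDeviceSynchronize();
      return 0;
    }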
@@ -369,6 +453,30 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge
   return cudaSuccess;
 }
 
+template <typename DTypeIn, typename DTypeOut, typename IdType>
+cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
+                                      float* s_merged, uint32_t seq_len, uint32_t num_heads,
+                                      uint32_t head_dim, cudaStream_t stream = nullptr) {
+  SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
+    constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
+    constexpr uint32_t bdx = HEAD_DIM / vec_size;
+    constexpr uint32_t num_threads = 128;
+    constexpr uint32_t bdy = num_threads / bdx;
+    dim3 nblks(seq_len, num_heads);
+    dim3 nthrs(bdx, bdy);
+    constexpr uint32_t num_smem_stages = 4;
+    auto kernel = VariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
+                                                  DTypeOut, IdType>;
+    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads};
+    uint32_t smem_size =
+        num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float);
+    FLASHINFER_CUDA_CALL(
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
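A hedged usage sketch for the new entry point (buffer sizes follow the documented shapes; the half element type, the caller name, and the fill step are assumptions for illustration; error handling elided):

    #include <cstdint>
    #include <cuda_fp16.h>

    void MergeVariableLengthStates(uint32_t seq_len, uint32_t num_heads, uint32_t head_dim,
                                   uint32_t nnz, int32_t* d_indptr /* length seq_len + 1 */) {
      half *d_v, *d_v_merged;
      float *d_s, *d_s_merged;
      cudaMalloc(&d_v, nnz * num_heads * head_dim * sizeof(half));            // V: (nnz, h, d)
      cudaMalloc(&d_s, nnz * num_heads * sizeof(float));                      // S: (nnz, h)
      cudaMalloc(&d_v_merged, seq_len * num_heads * head_dim * sizeof(half)); // (n, h, d)
      cudaMalloc(&d_s_merged, seq_len * num_heads * sizeof(float));           // (n, h)
      // ... produce the partial states in d_v/d_s and the offsets in d_indptr ...
      flashinfer::VariableLengthMergeStates(d_v, d_s, d_indptr, d_v_merged, d_s_merged,
                                            seq_len, num_heads, head_dim);
    }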
 }  // namespace flashinfer
 
 #endif  // FLASHINFER_CASCADE_CUH_