@@ -122,16 +122,24 @@ class FlashAttnFwdCombine {
122122 using ShapeLSE = cute::Shape<int32_t , int32_t , int32_t >; // (seqlen, head, batch)
123123 using StrideLSE = cute::Stride<_1, int64_t , int64_t >; // (seqlen, head, batch)
124124
125+ struct BlockCoord {
126+ int block_m;
127+ int block_k;
128+ int bidb;
129+ };
130+
125131 struct SharedStorage : cute::aligned_struct<128 > {
126132 cute::array_aligned<float , cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
127133 cute::array_aligned<int , kBlockM > smem_max_valid_split;
128134 cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
135+ BlockCoord block_coord;
129136 };
130137
131138 static constexpr int SharedStorageSize = sizeof (SharedStorage);
132139
133140 // Device side arguments
134141 struct Arguments {
142+ int b;
135143 ElementPartial const * const ptr_O_partial;
136144 ShapeOPartial const shape_O_partial;
137145 StrideOPartial const stride_O_partial;
@@ -149,7 +157,8 @@ class FlashAttnFwdCombine {
149157 };
150158
151159 // Kernel entry point API
152- struct Params {
160+ struct CollectiveParams {
161+ int b;
153162 ElementPartial const * const ptr_O_partial;
154163 ShapeOPartial const shape_O_partial;
155164 StrideOPartial const stride_O_partial;
@@ -169,10 +178,11 @@ class FlashAttnFwdCombine {
169178
170179 // Convert to underlying arguments. In this case, a simple copy for the aliased type.
171180 static
172- Params
181+ CollectiveParams
173182 to_underlying_arguments (Arguments const & args) {
174183 assert (get<1 >(args.shape_LSE_partial ) <= kMaxSplits );
175184 return {
185+ args.b ,
176186 args.ptr_O_partial ,
177187 args.shape_O_partial ,
178188 args.stride_O_partial ,
@@ -191,33 +201,243 @@ class FlashAttnFwdCombine {
191201 };
192202 }
193203
204+ struct SchedulerArguments {
205+ int b;
206+ int seqlen_q;
207+ int total_q;
208+ int num_heads;
209+ int dv;
210+ int const * cu_seqlens_q;
211+ int const * seqused_q;
212+ };
213+
214+ struct StaticTileScheduler {
215+ struct Params {};
216+ static Params to_underlying_arguments (SchedulerArguments const & args) { return {}; }
217+
218+ SharedStorage& shared_storage;
219+ CUTE_DEVICE StaticTileScheduler (SharedStorage& shared_storage): shared_storage(shared_storage) {}
220+
221+ static dim3 get_grid_shape (SchedulerArguments const & args) {
222+ unsigned int num_blocks_k = cute::ceil_div (args.dv , kBlockK );
223+ unsigned int num_blocks_m = cute::ceil_div (args.seqlen_q * args.num_heads , kBlockM );
224+ return {num_blocks_m, num_blocks_k, static_cast <unsigned int >(args.b )};
225+ }
226+
227+ CUTE_DEVICE BlockCoord get_block_coord (Params const & params) {
228+ int block_m = blockIdx.x ;
229+ int block_k = blockIdx.y ;
230+ int bidb = blockIdx.z ;
231+ return {block_m, block_k, bidb};
232+ }
233+ };
234+
235+ struct StaticVarlenTileScheduler {
236+ //
237+ // For varlen we have two scheduling algorithms:
238+ // 1) STANDARD, same as StaticTileScheduler
239+ // 2) LINEARIZE_M_AND_BATCH, which flattens the tiled M dimension and the
240+ // batch dimension into a single linear tile index. The grid is then a
241+ // 2D grid of (tile_id, k_block), and get_block_coord maps the linear
242+ // tile id back to (m_block, bidb). This mapping is non-trivial since
243+ // each batch element can have a different number of m_blocks, so it
244+ // adds overhead when computing the block coordinates, but it is more
245+ // efficient when prefills and decodes are mixed, since in that case the
246+ // STANDARD scheduling algorithm produces many empty (no-work) blocks
247+ // in the grid.
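+ //
+ // As a worked example (hypothetical sizes): if the per-batch m_block counts,
+ // ceil_div(seqlen_i * num_heads, kBlockM), are {3, 1, 4} for b = 3, then the
+ // linear tile ids 0..7 map to (m_block, bidb) =
+ // (0,0) (1,0) (2,0) (0,1) (0,2) (1,2) (2,2) (3,2), while block_k still
+ // comes from blockIdx.y.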
248+ //
249+
250+ enum SchedulingAlgo {
251+ STANDARD, // Same as StaticTileScheduler
252+ LINEARIZE_M_AND_BATCH, // Linearize the M and batch dimensions into a single tile index
253+ };
254+
255+ struct Params {
256+ int b;
257+ int num_heads;
258+ int const * const cu_seqlens_q;
259+ int const * const seqused_q;
260+ SchedulingAlgo algo;
261+ };
262+
263+ SharedStorage& shared_storage;
264+ CUTE_DEVICE StaticVarlenTileScheduler (SharedStorage& shared_storage): shared_storage(shared_storage) {}
265+
266+ static SchedulingAlgo choose_scheduling_algo (SchedulerArguments const & args) {
267+ // Choose the scheduling algorithm based on how dense the grid of tiles that
268+ // do actual work is. If the grid is more than 50% sparse, we linearize the M
269+ // and batch dimensions; if it is more than 50% dense, we use the standard scheduling
270+ // algorithm since it is more efficient at calculating the block coordinates.
271+ // NOTE: in the varlen case args.seqlen_q is the max seqlen_q across all batches,
272+ // so we use a lower bound to estimate when the density is more than 50%.
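+ // e.g. (hypothetical sizes): b = 64 with 60 decode batches (seqlen_q = 1), 4 prefill
+ // batches (seqlen_q = 4096), and kBlockM = 8: the lower bound is ceil(16444 / 8) = 2056
+ // non-empty tiles vs grid_size = 64 * ceil(4096 / 8) = 32768, so LINEARIZE_M_AND_BATCH
+ // is chosen.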
273+ int lower_bound_on_non_empty_tiles = cute::ceil_div (args.total_q , kBlockM );
274+ int grid_size = args.b * cute::ceil_div (args.seqlen_q , kBlockM );
275+ return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
276+ SchedulingAlgo::STANDARD :
277+ SchedulingAlgo::LINEARIZE_M_AND_BATCH;
278+ }
279+
280+ static Params to_underlying_arguments (SchedulerArguments const & args) {
281+ return {
282+ args.b ,
283+ args.num_heads ,
284+ args.cu_seqlens_q ,
285+ args.seqused_q ,
286+ choose_scheduling_algo (args)
287+ };
288+ }
289+
290+ static dim3 get_grid_shape (SchedulerArguments const & args) {
291+ unsigned int num_blocks_k = cute::ceil_div (args.dv , kBlockK );
292+
293+ switch (choose_scheduling_algo (args)) {
294+ case SchedulingAlgo::STANDARD: {
296+ unsigned int num_blocks_m = cute::ceil_div (args.seqlen_q * args.num_heads , kBlockM );
297+ return {num_blocks_m, num_blocks_k, static_cast <unsigned int >(args.b )};
298+ }
299+ case SchedulingAlgo::LINEARIZE_M_AND_BATCH: {
300+ // rough worst case upper bound on the number of blocks required
301+ // (assuming each batch has an additional partial block)
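+ // e.g. (hypothetical): per-batch seqlen_i * num_heads = {9, 9, 9} with kBlockM = 8 needs
+ // 2 + 2 + 2 = 6 tiles, while this bound gives ceil(27 / 8) + 3 = 4 + 3 = 7 blocks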
302+ unsigned int num_blocks_m = cute::ceil_div (args.total_q * args.num_heads , kBlockM ) + args.b ;
303+ return {num_blocks_m, num_blocks_k, 1 };
304+ }}
305+
306+ // Unreachable for a valid algo; fall back to the LINEARIZE_M_AND_BATCH worst-case
307+ // bound above so the compiler sees a return on every path.
308+ unsigned int num_blocks_m = cute::ceil_div (args.total_q * args.num_heads , kBlockM ) + args.b ;
309+ return {num_blocks_m, num_blocks_k, 1 };
310+ }
311+
312+ CUTE_DEVICE BlockCoord get_block_coord_linearized_m_and_batch (Params const & params) {
313+ int num_heads = params.num_heads ;
314+ int curr_tile_id = blockIdx.x ;
315+
316+ // Scan through the batches to find the batch that contains the current
317+ // tile_id. Compute using only the first warp of the block.
318+ if (threadIdx.x < 32 ) {
319+ // We compute linearized tile index start and ends for each batch
320+ // in groups of 32 in parallel
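+ // e.g. (hypothetical) with per-batch m_block counts {3, 1, 4}, the warp's cumulative
+ // tile-id ends are {3, 4, 8}; curr_tile_id = 5 then falls in batch 2
+ // (batch_m_start_tile_id = 4), giving block_m = 5 - 4 = 1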
321+ int group_start_bidb = -(cutlass::NumThreadsPerWarp);
322+ int group_end_bidb = 0 ;
323+ int group_end_tile_id = 0 ;
324+ int group_start_tile_id = 0 ;
325+ int group_total_num_tiles = 0 ;
326+
327+ int local_num_m_blocks = 0 ;
328+ int local_num_m_blocks_cumulative = 0 ;
329+
330+ do {
331+ group_start_bidb += cutlass::NumThreadsPerWarp;
332+ group_end_bidb += cutlass::NumThreadsPerWarp;
333+
334+ auto get_num_m_blocks = [&](int bidb) {
335+ if (bidb >= params.b ) return 0 ;
336+ flash::SeqlenInfo<Varlen, kBlockM > seqlen_info{bidb, 0 , params.cu_seqlens_q , params.seqused_q };
337+ return cute::ceil_div (seqlen_info.seqlen * num_heads, Int<kBlockM >{}());
338+ };
339+
340+ // Cumulative number of blocks for the next 31 batches
341+ local_num_m_blocks = get_num_m_blocks (group_start_bidb + threadIdx.x );
342+ local_num_m_blocks_cumulative = warp_prefix_sum (local_num_m_blocks);
343+ // Total number of blocks for the next 32 batches
344+ group_total_num_tiles = warp_shfl_get_last (local_num_m_blocks_cumulative);
345+
346+ group_start_tile_id = group_end_tile_id;
347+ group_end_tile_id += group_total_num_tiles;
348+ } while (curr_tile_id >= group_end_tile_id && group_end_bidb < params.b );
349+
350+ int local_batch_end_tile_id = group_start_tile_id + local_num_m_blocks_cumulative;
351+ // Find the last batch idx in the group where `local_batch_end_tile_id <= curr_tile_id`
352+ // these values below are now common to all threads in the warp
353+ int batch_idx_in_group = warp_last_true_laneid (local_batch_end_tile_id <= curr_tile_id);
354+ int batch_num_m_blocks = warp_shfl_get (local_num_m_blocks, batch_idx_in_group);
355+ int batch_m_start_tile_id = group_start_tile_id + (batch_idx_in_group > 0 ?
356+ warp_shfl_get (local_num_m_blocks_cumulative, batch_idx_in_group - 1 ) : 0 );
357+
358+ int bidb = group_start_bidb + batch_idx_in_group;
359+ int block_m = curr_tile_id - batch_m_start_tile_id;
360+ // NOTE(lucas): not sure why this causes a block_k unused warning
361+ // just inlined `blockIdx.y` to suppress the warning
362+ // int block_k = blockIdx.y;
363+ // shared_storage.block_coord = {block_m, block_k, bidb};
364+ BlockCoord block_coord{block_m, static_cast <int >(blockIdx.y ), bidb};
365+ if (threadIdx.x == 0 ) { shared_storage.block_coord = block_coord; }
366+ }
367+
368+ __syncthreads ();
369+ return shared_storage.block_coord ;
370+ }
371+
372+
373+ CUTE_DEVICE BlockCoord get_block_coord_standard (Params const & params) {
374+ int block_m = blockIdx.x ;
375+ int block_k = blockIdx.y ;
376+ int bidb = blockIdx.z ;
377+ return {block_m, block_k, bidb};
378+ }
379+
380+ CUTE_DEVICE BlockCoord get_block_coord (Params const & params) {
381+ switch (params.algo ) {
382+ case SchedulingAlgo::STANDARD:
383+ return get_block_coord_standard (params);
384+ case SchedulingAlgo::LINEARIZE_M_AND_BATCH:
385+ return get_block_coord_linearized_m_and_batch (params);
386+ }
387+ return {0 , 0 , 0 }; // Should never reach here
388+ }
389+ };
390+
391+ using TileScheduler = std::conditional_t <
392+ Varlen,
393+ StaticVarlenTileScheduler,
394+ StaticTileScheduler
395+ >;
396+
397+ using SchedulerParams = typename TileScheduler::Params;
398+
399+ struct Params {
400+ CollectiveParams params;
401+ SchedulerParams scheduler_params;
402+ };
403+
194404 CUTLASS_DEVICE
195405 void
196- operator ()(Params const & params, char * smem_buf) {
406+ operator ()(Params const & kernel_params, char * smem_buf) {
407+ CollectiveParams const & params = kernel_params.params ;
197408
198409 SharedStorage& shared_storage = *reinterpret_cast <SharedStorage*>(smem_buf);
410+ TileScheduler tile_scheduler{shared_storage};
411+
199412 Tensor sLSE = make_tensor (make_smem_ptr (shared_storage.smem_lse_partial .data ()), SmemLayoutLSE{});
200413 Tensor sMaxValidSplit = make_tensor (make_smem_ptr (shared_storage.smem_max_valid_split .data ()), Shape<Int<kBlockM >>{});
201414 Tensor sO = make_tensor (make_smem_ptr (shared_storage.smem_o_partial .data ()), SmemLayoutO{});
202415
203416 int const thread_idx = threadIdx.x ;
204- int const m_block = blockIdx.x ;
205- int const k_block = blockIdx.y ;
206- int const batch = blockIdx.z ;
207- int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr [batch] : get<1 >(params.shape_LSE_partial );
417+
418+ BlockCoord block_coord = tile_scheduler.get_block_coord (kernel_params.scheduler_params );
419+
420+ int const m_block = block_coord.block_m ;
421+ int const k_block = block_coord.block_k ;
422+ int const batch = block_coord.bidb ;
208423
209424 if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1 ) {
210425 cutlass::arch::wait_on_dependent_grids ();
211426 *params.semaphore_to_reset = 0 ;
212427 }
213- if (num_splits <= 1 ) { return ; }
428+
+ // The linearized varlen scheduler can hand trailing (no-work) blocks a bidb past the last
+ // batch; bail out here, before seqlen_info below indexes cu_seqlens/seqused with it.
+ if (block_coord.bidb >= params.b ) { return ; }
+
214429 flash::SeqlenInfo<Varlen, kBlockM > seqlen_info{batch, size<0 >(params.shape_LSE_partial ), params.cu_seqlens , params.seqused };
215430 int const offset = seqlen_info.offset ;
216431 int const seqlen = seqlen_info.seqlen ;
217432 int max_idx = seqlen * get<2 >(params.shape_LSE_partial );
218- if constexpr (Varlen) {
219- if (m_block * kBlockM >= max_idx) { return ; }
220- }
433+
434+ bool block_coord_valid =
435+ block_coord.block_m < cute::ceil_div (max_idx, Int<kBlockM >{}) &&
436+ block_coord.bidb < params.b ;
437+ if (!block_coord_valid) { return ; }
438+
439+ int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr [batch] : get<1 >(params.shape_LSE_partial );
440+ if (num_splits <= 1 ) { return ; }
221441
222442 cutlass::FastDivmod seqlen_divmod_dynamic (seqlen);
223443