Commit 603c1ad

put in a more readable heuristic
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 355cb84 commit 603c1ad

1 file changed: +4 -2 lines changed
hopper/flash_fwd_combine_kernel.h

Lines changed: 4 additions & 2 deletions
@@ -269,8 +269,10 @@ class FlashAttnFwdCombine {
 // and batch. If the grid is more than 50% dense, we use the standard scheduling
 // algorithm since its more efficient at calculating the block coordinates.
 // NOTE: in varlen case args.seqlen_q is the max seqlen_q across all batches
-// if the density is over 50% we use the standard scheduling algo
-return cute::ceil_div(args.total_q, args.seqlen_q) >= cute::ceil_div(args.b, 2) ?
+// use lower bound to estimate when the density is more than 50%
+int lower_bound_on_non_empty_tiles = cute::ceil_div(args.total_q, kBlockM);
+int grid_size = args.b * cute::ceil_div(args.seqlen_q, kBlockM);
+return 2 * lower_bound_on_non_empty_tiles >= grid_size ?
 SchedulingAlgo::STANDARD :
 SchedulingAlgo::LINEARIZE_M_AND_BATCH;
 }
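
To make the change easier to reason about, here is a minimal standalone sketch of the old and new density checks side by side. It is not the kernel code: the `Args` struct, the local `ceil_div` helper (standing in for `cute::ceil_div`), and the function names `old_heuristic` and `new_heuristic` are hypothetical, and `kBlockM` is passed as a template parameter purely for illustration. The new form relies on the fact that each M-tile holds at most `kBlockM` query rows, so `total_q` rows occupy at least `ceil_div(total_q, kBlockM)` non-empty tiles.

```cpp
// Hypothetical stand-in for the kernel arguments; field names follow the diff.
struct Args {
    int b;         // batch size
    int seqlen_q;  // in the varlen case: max seqlen_q across all batches
    int total_q;   // total number of query rows summed over all batches
};

enum class SchedulingAlgo { STANDARD, LINEARIZE_M_AND_BATCH };

// Integer ceiling division for positive operands (stand-in for cute::ceil_div).
inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Check as written before the commit: compares the number of "max-length
// equivalents" in total_q against half the batch size.
inline SchedulingAlgo old_heuristic(Args args) {
    return ceil_div(args.total_q, args.seqlen_q) >= ceil_div(args.b, 2)
        ? SchedulingAlgo::STANDARD
        : SchedulingAlgo::LINEARIZE_M_AND_BATCH;
}

// Check as written after the commit: compares a lower bound on the number of
// non-empty tiles against the full (batch x M-tiles) grid size.
template <int kBlockM>
inline SchedulingAlgo new_heuristic(Args args) {
    int lower_bound_on_non_empty_tiles = ceil_div(args.total_q, kBlockM);
    int grid_size = args.b * ceil_div(args.seqlen_q, kBlockM);
    return 2 * lower_bound_on_non_empty_tiles >= grid_size
        ? SchedulingAlgo::STANDARD
        : SchedulingAlgo::LINEARIZE_M_AND_BATCH;
}
```

With an assumed `kBlockM` of 8, a fully dense batch (`b = 4`, `seqlen_q = 64`, `total_q = 256`) selects `STANDARD` under both forms, and a batch with one long sequence and seven short ones (`b = 8`, `seqlen_q = 1024`, `total_q = 1080`) selects `LINEARIZE_M_AND_BATCH` under both. The two forms are not equivalent near the threshold, though: for `b = 3`, `seqlen_q = 64`, `total_q = 65`, the old check picks `STANDARD` while the new one picks `LINEARIZE_M_AND_BATCH`.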