@@ -292,7 +292,6 @@ class FlashAttnFwdCombine {
292292
293293 switch (choose_scheduling_algo (args)) {
294294 case SchedulingAlgo::STANDARD: {
295- unsigned int num_blocks_k = cute::ceil_div (args.dv , kBlockK );
296295 unsigned int num_blocks_m = cute::ceil_div (args.seqlen_q * args.num_heads , kBlockM );
297296 return {num_blocks_m, num_blocks_k, static_cast <unsigned int >(args.b )};
298297 }
@@ -426,15 +425,13 @@ class FlashAttnFwdCombine {
426425 *params.semaphore_to_reset = 0 ;
427426 }
428427
428+ if (batch >= params.b ) { return ; }
429429 flash::SeqlenInfo<Varlen, kBlockM > seqlen_info{batch, size<0 >(params.shape_LSE_partial ), params.cu_seqlens , params.seqused };
430430 int const offset = seqlen_info.offset ;
431431 int const seqlen = seqlen_info.seqlen ;
432432 int max_idx = seqlen * get<2 >(params.shape_LSE_partial );
433433
434- bool block_coord_valid =
435- block_coord.block_m < cute::ceil_div (max_idx, Int<kBlockM >{}) &&
436- block_coord.bidb < params.b ;
437- if (!block_coord_valid) { return ; }
434+ if (m_block >= cute::ceil_div (max_idx, Int<kBlockM >{})) { return ; }
438435
439436 int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr [batch] : get<1 >(params.shape_LSE_partial );
440437 if (num_splits <= 1 ) { return ; }
0 commit comments