From 4c6a539a4d69241c612fea88b6af6c46d05eb542 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 15 Aug 2025 05:31:33 +0000 Subject: [PATCH] potential IMA (illegal memory access) fix Signed-off-by: Lucas Wilkinson --- hopper/flash_fwd_combine_kernel.h | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index 3aa5484cbc..826ccef6f8 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -292,7 +292,6 @@ class FlashAttnFwdCombine { switch (choose_scheduling_algo(args)) { case SchedulingAlgo::STANDARD: { - unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK); unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM); return {num_blocks_m, num_blocks_k, static_cast(args.b)}; } @@ -426,15 +425,13 @@ class FlashAttnFwdCombine { *params.semaphore_to_reset = 0; } + if (batch >= params.b) { return; } flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; int max_idx = seqlen * get<2>(params.shape_LSE_partial); - bool block_coord_valid = - block_coord.block_m < cute::ceil_div(max_idx, Int{}) && - block_coord.bidb < params.b; - if (!block_coord_valid) { return; } + if (m_block >= cute::ceil_div(max_idx, Int{})) { return; } int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); if (num_splits <= 1) { return; }