Skip to content

Commit 57b4e68

Browse files
potential IMA (illegal memory access) fix (#80)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 2d3b750 commit 57b4e68

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

hopper/flash_fwd_combine_kernel.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ class FlashAttnFwdCombine {
292 292

293 293
switch (choose_scheduling_algo(args)) {
294 294
case SchedulingAlgo::STANDARD: {
295 -
unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
296 295
unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
297 296
return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
298 297
}
@@ -426,15 +425,13 @@ class FlashAttnFwdCombine {
426 425
*params.semaphore_to_reset = 0;
427 426
}
428 427

428 +
if (batch >= params.b) { return; }
429 429
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
430 430
int const offset = seqlen_info.offset;
431 431
int const seqlen = seqlen_info.seqlen;
432 432
int max_idx = seqlen * get<2>(params.shape_LSE_partial);
433 433

434 -
bool block_coord_valid =
435 -
block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
436 -
block_coord.bidb < params.b;
437 -
if (!block_coord_valid) { return; }
434 +
if (m_block >= cute::ceil_div(max_idx, Int<kBlockM>{})) { return; }
438 435

439 436
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
440 437
if (num_splits <= 1) { return; }

0 commit comments

Comments (0)