diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
index 0e60b24c3d..a0da0ed2ad 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
@@ -1799,8 +1799,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
     } else if constexpr (std::is_base_of_v, Mask>) {
       int offset = get<1>(problem_shape) - get<0>(problem_shape);
       iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{});
-    } else if constexpr (std::is_base_of_v, Mask>) {
-      int offset = get<1>(problem_shape) - get<0>(problem_shape);
+    }
+    else if constexpr (
+        std::is_base_of_v, Mask> ||
+        std::is_base_of_v, Mask>
+    ) {
+      int offset = std::is_base_of_v, Mask>
+          ? get<1>(problem_shape) - get<0>(problem_shape)
+          : 0;
       int k_max = (get<1>(blk_coord) + 1) * TileShapeK{};
       int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left);
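
For context, the hunk above changes how the backward kernel bounds the Q tiles a given K tile must visit: the causal-style offset (seqlen_kv - seqlen_q) is now applied only when the mask diagonal is anchored at the end of the sequences, otherwise the offset is zero, and the left window size then caps q_max. The standalone host-side sketch below reproduces that arithmetic under those assumptions; all names here (q_tile_upper_bound, seqlen_q, seqlen_kv, is_q_begin, window_size_left, tile_q, tile_k) are simplified placeholders for illustration, not identifiers from the kernel.

// Illustrative sketch only (not part of the diff): how the per-K-tile upper
// bound on Q tiles follows from the mask parameters, mirroring the q_max
// computation in the hunk above. Names are hypothetical simplifications.
#include <algorithm>
#include <cstdio>

int q_tile_upper_bound(int k_blk, int seqlen_q, int seqlen_kv,
                       bool is_q_begin, int window_size_left,
                       int tile_q, int tile_k) {
  // When the diagonal is anchored at the end of the sequences
  // (is_q_begin == false), query q may attend key k only if
  // k <= q + (seqlen_kv - seqlen_q); otherwise the offset is zero.
  int offset = is_q_begin ? 0 : seqlen_kv - seqlen_q;
  // Largest key index (exclusive) covered by this K tile.
  int k_max = (k_blk + 1) * tile_k;
  // With a left window, key k is visible only to queries q satisfying
  // q + offset - window_size_left <= k, so q < k_max - offset + window_size_left.
  int q_max = std::min(seqlen_q, k_max - offset + window_size_left);
  // Number of Q tiles that can contain unmasked work for this K tile.
  return (q_max + tile_q - 1) / tile_q;
}

int main() {
  // Example: 256 queries, 384 keys, end-anchored diagonal, left window of 64,
  // 64x64 tiles; K tile 2 then only reaches Q tiles [0, 2).
  std::printf("%d\n", q_tile_upper_bound(2, 256, 384, false, 64, 64, 64));
  return 0;
}

The point of the conditional offset in the diff is that an end-anchored (bottom-right aligned) diagonal shifts every key index by seqlen_kv - seqlen_q relative to the query index, while a start-anchored diagonal does not; folding both cases into one branch lets the causal and local mask variants share the same q_max bound.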