From a557a2fae55cd7dbc6014f86ceafa9d5591fb659 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Thu, 2 Oct 2025 10:50:05 -0700 Subject: [PATCH] bwd case when offset is 0 Summary: This case is for when we are not using bottom right mask. It should be slightly better perf in that case. # notes We note that backward is in general not stable. Sometimes you can get IMA. And numerics are not as good as we want it to be. Differential Revision: D83076701 --- .../sm100_fmha_bwd_kernel_tma_warpspecialized.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 0e60b24c3d..a0da0ed2ad 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -1799,8 +1799,14 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if constexpr (std::is_base_of_v, Mask>) { int offset = get<1>(problem_shape) - get<0>(problem_shape); iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); - } else if constexpr (std::is_base_of_v, Mask>) { - int offset = get<1>(problem_shape) - get<0>(problem_shape); + } + else if constexpr ( + std::is_base_of_v, Mask> || + std::is_base_of_v, Mask> + ) { + int offset = std::is_base_of_v, Mask> + ? get<1>(problem_shape) - get<0>(problem_shape) + : 0; int k_max = (get<1>(blk_coord) + 1) * TileShapeK{}; int q_max = min(get<0>(problem_shape), k_max - offset + params.mainloop_params.window_size_left);