diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 919a7aa8c0..2cec9c713a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -269,14 +269,14 @@ def skcheck(self) -> str: return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true if self.pipeline_tag == "qr_async": if self.skpad == "t": - return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" + return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)" else: - return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" elif self.pipeline_tag in ["qr", "qs"]: if self.skpad == "t": return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) else: - return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" + return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)" elif self.pipeline_tag == "qr_async_trload": if self.skpad == "t": return "true" diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index 73b3c1e619..3f8071be32 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[]) "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch " "(group mode)") + .insert("s_qpad", + "-1", + "padded seqlen_q per batch (group mode only). " + "Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding") .insert("s_k", "-1", "seqlen_k, -1 means equal to s\n" "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch " "(group mode)") + .insert("s_kpad", + "-1", + "padded seqlen_k per batch (group mode only). " + "Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale", "0", "scale factor. 
0 means equal to 1/sqrt(hdim)") @@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); auto seqlen_qs = arg_parser.get_int_vec("s"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); auto seqlen_ks = arg_parser.get_int_vec("s_k"); + auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); bool i_perm = arg_parser.get_bool("iperm"); @@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser) nhead_k, seqlen_qs, seqlen_ks, + seqlen_qpads, + seqlen_kpads, hdim_q, hdim_v, i_perm, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 6cd1cd94fa..570a4bed82 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -114,9 +114,51 @@ struct fmha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* seqlen_k_ptr; + + // Usage notes for sequence length pointer parameters: + // + // Group mode: variable-length sequences are packed back-to-back along the seqlen axis + // and located via the seqstart_* prefix sums; batch mode uses one fixed seqlen per batch. + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Record the logical (excluding padding) length of each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding)
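To make the three array kinds concrete, here is a small host-side sketch (an annotation, not part of the patch; the lengths are invented) for a two-sequence group-mode batch with logical lengths {3, 5} padded to physical lengths {4, 8}:

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    const std::vector<int32_t> logical_len  = {3, 5}; // what seqlen_*_ptr points at   [batch]
    const std::vector<int32_t> physical_len = {4, 8}; // per-sequence lengths incl. PAD [batch]

    // seqstart_*: prefix sum of physical lengths -> {0, 4, 12}; used for memory offsets
    std::vector<int32_t> seqstart(physical_len.size() + 1, 0);
    std::partial_sum(physical_len.begin(), physical_len.end(), seqstart.begin() + 1);

    // cu_seqlen_*: prefix sum of logical lengths -> {0, 3, 8}; used for compute bounds
    std::vector<int32_t> cu_seqlen(logical_len.size() + 1, 0);
    std::partial_sum(logical_len.begin(), logical_len.end(), cu_seqlen.begin() + 1);

    // A kernel can recover the logical length of a batch entry from either encoding,
    // which is why seqlen_*_ptr and cu_seqlen_*_ptr are mutually exclusive.
    assert(cu_seqlen[2] - cu_seqlen[1] == logical_len[1]);
    assert(seqstart[2] - seqstart[1] == physical_len[1]);
    return 0;
}
```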
ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -203,7 +245,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, args.seqlen_k_ptr, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, @@ -315,6 +360,8 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) args.d_ptr, args.p_undrop, args.seqstart_q_ptr, + args.seqlen_q_ptr, + args.cu_seqlen_q_ptr, args.hdim_v, args.stride_do, args.stride_o, @@ -356,6 +403,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) args.dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, + args.seqlen_k_ptr, + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr, args.hdim_q, args.stride_dq, args.stride_dq_acc, diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index b6f2c8ca30..52adcdc21d 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -65,6 +65,8 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::index_t nhead_k, std::vector seqlen_qs, std::vector seqlen_ks, + std::vector seqlen_qpads, + std::vector seqlen_kpads, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, bool i_perm, @@ -119,13 +121,26 @@ bwd_result fmha_bwd_run(mode_enum mode, std::cerr << "dbias only exists when bias type is elementwise" << std::endl; return bwd_result::invalid_args; } - std::vector seqlen_kpads; - std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = - generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, {}, 0, false, random_engine); - ck_tile::ignore = seqlen_kpads; + + std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens( + mode, batch, seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads, 0, false, random_engine); + + bool use_qpadding = + mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1); + bool use_kpadding = + mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] != -1); + #if 0 + std::cout << "use_qpadding: " << use_qpadding << std::endl; + std::cout << "use_kpadding: " << use_kpadding << std::endl; std::cout << "seqlen_qs: " << seqlen_qs << std::endl; std::cout << "seqlen_ks: " << seqlen_ks << std::endl; + if (use_qpadding) { + std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl; + } + if (use_kpadding) { + std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; + } #endif mask_info mask = mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); @@ -146,8 +161,10 @@ bwd_result fmha_bwd_run(mode_enum mode, s_randval = true; } - const auto seqstart_q_host = to_seqstarts(seqlen_qs); - const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_q_host = + (use_qpadding ? to_seqstarts(seqlen_qpads) : to_seqstarts(seqlen_qs)); + const auto seqstart_k_host = + (use_kpadding ? to_seqstarts(seqlen_kpads) : to_seqstarts(seqlen_ks)); using TypeConfig = FmhaBwdTypeConfig; @@ -176,8 +193,11 @@ bwd_result fmha_bwd_run(mode_enum mode, { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + // When padding is enabled, use logical lengths for flop/bandwidth calculation + const int32_t real_seqlen_q = + use_qpadding ? 
seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]); + const int32_t real_seqlen_k = + use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]); if(max_seqlen_q < real_seqlen_q) { @@ -336,6 +356,10 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_q_dev(mode == mode_enum::batch ? 0 + : seqlen_qs.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_dev(mode == mode_enum::batch ? 0 + : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); @@ -349,6 +373,13 @@ bwd_result fmha_bwd_run(mode_enum mode, do_buf.ToDevice(do_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + if(mode == mode_enum::group) + { + std::vector seqlen_q_host(seqlen_qs.begin(), seqlen_qs.end()); + seqlen_q_dev.ToDevice(seqlen_q_host.data()); + std::vector seqlen_k_host(seqlen_ks.begin(), seqlen_ks.end()); + seqlen_k_dev.ToDevice(seqlen_k_host.data()); + } drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); @@ -440,6 +471,9 @@ bwd_result fmha_bwd_run(mode_enum mode, } }(); + const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr; + const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr; + return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -457,6 +491,9 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_acc_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), + seqlen_q_ptr_dev, + seqlen_k_ptr_dev, + nullptr, nullptr, shape_seqlen_q, shape_seqlen_k, @@ -551,8 +588,18 @@ bwd_result fmha_bwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + // When padding is enabled, use logical lengths instead of computing from padded + // prefix-sum + const ck_tile::index_t real_seqlen_q = + use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]); + const ck_tile::index_t real_seqlen_k = + use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]); + + // Skip forward reference computation for batches with zero length sequences + if(real_seqlen_q == 0 || real_seqlen_k == 0) + { + continue; + } // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? 
wb : 0); @@ -797,10 +844,23 @@ bwd_result fmha_bwd_run(mode_enum mode, dv_buf.FromDevice(dv_host.data()); dbias_buf.FromDevice(dbias_host.data()); + // Track the index into reference vectors (may differ from wb if batches were skipped) + ck_tile::index_t ref_idx = 0; + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + // When padding is enabled, use logical lengths instead of computing from padded + // prefix-sum + const ck_tile::index_t real_seqlen_q = + use_qpadding ? seqlen_qs[wb] : (seqstart_q_host[wb + 1] - seqstart_q_host[wb]); + const ck_tile::index_t real_seqlen_k = + use_kpadding ? seqlen_ks[wb] : (seqstart_k_host[wb + 1] - seqstart_k_host[wb]); + + // Skip validation for batches with zero length sequences + if(real_seqlen_q == 0 || real_seqlen_k == 0) + { + continue; + } // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); @@ -833,14 +893,14 @@ bwd_result fmha_bwd_run(mode_enum mode, // dP = dO@V x Z w/ dropout // dP = dO@V w/o dropout - auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + auto v_t_host_ref = v_host_refs[ref_idx].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o ck_tile::reference_batched_gemm( do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o if(p_drop > 0) { ck_tile::reference_batched_dropout( - dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop); } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) @@ -849,11 +909,13 @@ bwd_result fmha_bwd_run(mode_enum mode, AccDataType do_dot_o = 0; for(int o = 0; o < hdim_v; o++) { - do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * - ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + do_dot_o += + ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[ref_idx](i0, i1, o)); } - ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( - p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + ds_hp_host_ref(i0, i1, i2) = + ck_tile::type_convert(p_hp_host_refs[ref_idx](i0, i1, i2) * + (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); }, ds_hp_host_ref.mDesc.get_lengths()[0], ds_hp_host_ref.mDesc.get_lengths()[1], @@ -869,14 +931,14 @@ bwd_result fmha_bwd_run(mode_enum mode, // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout auto p_t_lp_host_ref = - p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + p_lp_host_refs[ref_idx].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m ck_tile:: reference_batched_gemm( p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m // dQ = scale * dS@K^T - auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n ck_tile::reference_batched_gemm( ds_lp_host_ref, k_t_host_ref, @@ -886,8 +948,8 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n // dK = scale * dS^T@Q^T - auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m - auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> 
ds_g_n_m + auto q_t_host_ref = q_host_refs[ref_idx].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m ck_tile::reference_batched_gemm( ds_t_lp_host_ref, q_t_host_ref, @@ -961,6 +1023,9 @@ bwd_result fmha_bwd_run(mode_enum mode, break; } + + // Advance the reference index; skipped zero-length batches have no reference entry + ref_idx++; } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 761def6d6a..383be6e099 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -182,19 +182,50 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; - // Optional cumulative sequence length arrays - // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] - - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - - // Group mode: seqstart_padded_* provide physical starts including PAD (optional) - const void* seqstart_padded_q_ptr = nullptr; // [batch+1] - const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + // Usage notes for sequence length pointer parameters: + // + // Group mode: variable-length sequences are packed back-to-back along the seqlen axis + // and located via the seqstart_* prefix sums; batch mode uses one fixed seqlen per batch. + // + // With padding: + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence + // lengths. [array size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr: Record the logical (excluding padding) length of each + // sequence. [array size: batch] + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually + // exclusive. Use one set, not both. + // + // Batch mode: + // - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Record cumulative logical (excluding padding) + // sequence lengths. [array size: batch + 1] + // - seqstart_* and seqlen_* pointers must be nullptr. + // + // Without padding: + // (Note: Physical length equals logical length) + // + // Group mode: + // - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array + // size: batch + 1] + // - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr. + // + // Batch mode: + // - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr. + // + const void* seqstart_q_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqstart_k_ptr = + nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode) + const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array + // [batch]. (Used in Group mode with padding) + const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding) + const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length + // array [batch + 1]. (Used with padding)
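As a reading aid (annotation, not part of the patch): the forward kernel later in this diff resolves the per-batch logical length with a three-tier fallback, seqlen_q_ptr first, then cu_seqlen_q_ptr, then the seqstart_q_ptr difference. A standalone sketch of that resolution, with a hypothetical helper name:

```cpp
#include <cstdint>

// Sketch of the group-mode fallback the forward kernel uses to obtain the
// logical seqlen of batch i_batch:
//   per-sequence array > cumulative logical array > physical prefix-sum difference.
inline int32_t resolve_logical_seqlen(const int32_t* seqlen_ptr,    // [batch], may be nullptr
                                      const int32_t* cu_seqlen_ptr, // [batch + 1], may be nullptr
                                      const int32_t* seqstart_ptr,  // [batch + 1], always set
                                      int32_t i_batch)
{
    if(seqlen_ptr != nullptr)
        return seqlen_ptr[i_batch];
    if(cu_seqlen_ptr != nullptr)
        return cu_seqlen_ptr[i_batch + 1] - cu_seqlen_ptr[i_batch];
    // no padding: physical length equals logical length
    return seqstart_ptr[i_batch + 1] - seqstart_ptr[i_batch];
}
```

Note that the backward kernels in this patch check cu_seqlen_*_ptr before seqlen_*_ptr; since the two encodings are declared mutually exclusive above, the ordering is only a tie-breaker.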
ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -555,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, + args.seqlen_q_ptr, args.seqlen_k_ptr, args.hdim_q, args.hdim_v, @@ -584,8 +616,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, - args.seqstart_padded_q_ptr, - args.seqstart_padded_k_ptr); + args.cu_seqlen_q_ptr, + args.cu_seqlen_k_ptr); } else { // create batch mode kernel arguments @@ -633,7 +665,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_kv_ptr); + args.cu_seqlen_k_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 0703af71e3..ca3cd51c57 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -313,16 +313,19 @@ fwd_result fmha_fwd_run(mode_enum mode, const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) - const bool has_group_padding = - (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || - (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); - const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || - !kv_eff_lens_per_batch.empty())); - const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); - const bool using_pagedkv = (0 < page_block_size); - const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + const bool has_group_q_padding = + mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] > 0); + const bool has_group_k_padding = + mode == mode_enum::group && (!seqlen_kpads.empty() && seqlen_kpads[0] > 0); + const bool has_group_padding = has_group_q_padding || has_group_k_padding; + const bool has_batch_q_padding = mode == mode_enum::batch && !q_eff_lens_per_batch.empty(); + const bool has_batch_k_padding = mode == mode_enum::batch && !kv_eff_lens_per_batch.empty(); + const bool has_batch_padding = has_batch_q_padding || has_batch_k_padding; + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; if((using_appendkv || using_pagedkv || using_splitkv) && - (has_group_padding || has_batch_efflens)) + (has_group_padding || has_batch_padding)) { std::cerr << "Padding (physical or effective lengths) is not supported with " "appendkv/splitkv/pagedkv pipelines" << std::endl; return fwd_result::invalid_args; } - std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = + std::tie(seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads) = generate_missing_seqlens(mode, batch, seqlen_qs, seqlen_ks, + seqlen_qpads, seqlen_kpads, /*seqlen_k_min=*/0 < seqlen_knew ? 
seqlen_knew : 0, need_append_kvcache, @@ -346,7 +350,13 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; return fwd_result::invalid_args; } + if(seqlen_qpads[wb] > 0 && seqlen_qpads[wb] < seqlen_qs[wb]) + { + std::cerr << "qpad must be greater than or equal to seqlen for q" << std::endl; + return fwd_result::invalid_args; + } } + // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -357,6 +367,7 @@ fwd_result fmha_fwd_run(mode_enum mode, #if 0 std::cout << "seqlen_qs: " << seqlen_qs << std::endl; std::cout << "seqlen_ks: " << seqlen_ks << std::endl; + std::cout << "seqlen_qpads: " << seqlen_qpads << std::endl; std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl; #endif @@ -391,23 +402,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_q_host = to_seqstarts(seqlen_qs); const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_q_with_padding_host = to_seqstarts(seqlen_qpads); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - // Optional padded Q seqstarts (group-mode only) - std::vector seqstart_q_with_padding_host; - if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) - { - if(seqlen_qpads.size() < static_cast(batch)) - { - seqlen_qpads.resize(batch, seqlen_qpads.back()); - } - if(seqlen_qpads.size() == static_cast(batch)) - { - seqstart_q_with_padding_host = to_seqstarts( - ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); - } - } - // Optional batch-mode cumulative seqlen overrides std::vector cuq_cum, cukv_cum; if(mode == mode_enum::batch) @@ -514,19 +511,17 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen - const ck_tile::index_t shape_seqlen_q_lse = - (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch - ? seqlen_qs[0] - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() - : seqstart_q_with_padding_host.back())); + (mode == mode_enum::batch ? seqlen_qs[0] + : (has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_with_padding_host.back() + : seqstart_q_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] - : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() - : seqstart_k_with_padding_host.back())); + : (has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? seqstart_k_with_padding_host.back() + : seqstart_k_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -580,7 +575,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} + lse ? 
std::array{shape_batch, nhead, shape_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -684,14 +679,18 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(int32_t)); ck_tile::DeviceMem seqstart_k_padded_buf( seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + // Buffers for query per-sequence logical (unpadded) lengths (used in group mode with padding + // enabled) + ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0); + // Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with + // kvcache or group mode with padding enabled) + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding + ? seqlen_ks.size() * sizeof(int32_t) + : 0); ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 : cuq_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem cu_seqlen_kv_buf( cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || - 0 <= seqlen_kpads[0] - ? seqlen_ks.size() * sizeof(int32_t) - : 0); ck_tile::DeviceMem cache_seqlen_k_buf( need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); @@ -787,7 +786,8 @@ fwd_result fmha_fwd_run(mode_enum mode, : seqstart_k_with_padding_host.data()); cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); - seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding ? seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); @@ -868,7 +868,7 @@ fwd_result fmha_fwd_run(mode_enum mode, print_vec("k_padded", seqlen_kpads); } } - else if(has_batch_efflens) + else if(has_batch_padding) { // derive effective lengths from cumulative arrays if present if(!cuq_cum.empty()) @@ -970,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -986,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -1051,14 +1051,6 @@ fwd_result fmha_fwd_run(mode_enum mode, args.lse_ptr = lse_buf.GetDeviceBuffer(); args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = - (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = - (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] - ? seqlen_k_buf.GetDeviceBuffer() - : nullptr); - args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -1102,27 +1094,54 @@ fwd_result fmha_fwd_run(mode_enum mode, args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } - // Group-mode: optional physical padded starts for Q/K + // Sequence length and padding parameters (mode-specific) if(mode == mode_enum::group) { - args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() - ? nullptr - : seqstart_q_padded_buf.GetDeviceBuffer()); - args.seqstart_padded_k_ptr = - (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + // Group mode: use physical (padded) cumulative starts + logical per-sequence + // lengths + + // Physical cumulative starts (including padding) + args.seqstart_q_ptr = + has_group_q_padding && !seqstart_q_with_padding_host.empty() + ? seqstart_q_padded_buf.GetDeviceBuffer() + : seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = + has_group_k_padding && !seqstart_k_with_padding_host.empty() + ? seqstart_k_padded_buf.GetDeviceBuffer() + : seqstart_k.GetDeviceBuffer(); + + // Logical (unpadded) per-sequence lengths, used when padding is enabled + args.seqlen_q_ptr = + (has_group_q_padding && !seqstart_q_with_padding_host.empty()) + ? seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.seqlen_k_ptr = + (has_group_k_padding && !seqstart_k_with_padding_host.empty()) + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr; + // Cumulative lengths not used in group mode + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; } - - // Batch-mode: optional cumulative effective seqlen overrides - if(mode == mode_enum::batch) + else // mode == mode_enum::batch { - args.cu_seqlen_q_ptr = cuq_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_q_buf.GetDeviceBuffer()); - args.cu_seqlen_kv_ptr = cukv_cum.empty() - ? 
nullptr - : reinterpret_cast( - cu_seqlen_kv_buf.GetDeviceBuffer()); + // Batch mode: use cumulative logical lengths for tail padding + + // seqstart pointers not used in batch mode + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; + + // seqlen_q_ptr/seqlen_k_ptr not used in batch mode + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + + // Cumulative logical lengths for effective length handling + args.cu_seqlen_q_ptr = has_batch_q_padding && !cuq_cum.empty() + ? cu_seqlen_q_buf.GetDeviceBuffer() + : nullptr; + args.cu_seqlen_k_ptr = has_batch_k_padding && !cukv_cum.empty() + ? cu_seqlen_kv_buf.GetDeviceBuffer() + : nullptr; } } else if constexpr(std::is_same_v>) @@ -1148,6 +1167,15 @@ fwd_result fmha_fwd_run(mode_enum mode, args.batch_stride_o_acc = batch_stride_o_acc; args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); } else if constexpr(std::is_same_v>) { @@ -1159,6 +1187,15 @@ fwd_result fmha_fwd_run(mode_enum mode, args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); } } }; @@ -1360,16 +1397,19 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); + // Use physical offset if padding info is valid (not -1) and buffers are available const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] - : seqstart_q_with_padding_host[wb])); + : ((seqstart_q_with_padding_host.empty() || seqlen_qpads[0] < 0) + ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 - : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] - : seqstart_k_with_padding_host[wb])); + : ((seqstart_k_with_padding_host.empty() || seqlen_kpads[0] < 0) + ? 
seqstart_k_host[wb] + : seqstart_k_with_padding_host[wb])); ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); @@ -1718,8 +1758,14 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cerr << "OUT mismatch found at batch: " << wb << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; + << "\tseqstart_q (logical): " << seqstart_q_host << std::endl + << "\tseqstart_q (physical): " << seqstart_q_with_padding_host + << std::endl + << "\tseqstart_k (logical): " << seqstart_k_host << std::endl + << "\tseqstart_k (physical): " << seqstart_k_with_padding_host + << std::endl + << "\tquery_offset used: " << query_offset << std::endl + << "\tkey_offset used: " << key_offset << std::endl; break; } @@ -1727,10 +1773,8 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - const ck_tile::index_t query_offset_lse = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 7f44d87180..0303ded238 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -142,12 +142,14 @@ auto randints(ForwardIterator first, */ template std::tuple, + std::vector, std::vector, std::vector> generate_missing_seqlens(mode_enum mode, ck_tile::index_t batch, const std::vector& q_val, const std::vector& k_val, + const std::vector& q_pad_val, const std::vector& k_pad_val, ck_tile::index_t seqlen_k_min, bool need_append_kvcache, @@ -177,7 +179,7 @@ generate_missing_seqlens(mode_enum mode, return seqlen_ks; }(); auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding - + auto s_qpad = std::vector(batch, -1); // s_k should be greater than or equal to seqlen_k_min if provided if(s_k.back() < seqlen_k_min) { @@ -187,13 +189,14 @@ generate_missing_seqlens(mode_enum mode, throw std::runtime_error(msg.str()); } - return std::make_tuple(s_q, s_k, s_kpad); + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); } else { std::vector s_q; std::vector s_k; std::vector s_kpad; + std::vector s_qpad; ck_tile::index_t idx = 0; for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) { @@ -205,9 +208,15 @@ generate_missing_seqlens(mode_enum mode, ? -1 : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; + ck_tile::index_t qp = + q_pad_val.empty() + ? -1 + : q_pad_val[std::min(idx, static_cast(q_pad_val.size()) - 1)]; + s_q.push_back(q); s_k.push_back(k < 0 ? 
q : k); s_kpad.push_back(kp); + s_qpad.push_back(qp); // s_k should be greater than or equal to seqlen_k_min if(s_k.back() < seqlen_k_min) @@ -228,8 +237,9 @@ generate_missing_seqlens(mode_enum mode, s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back()); } - return std::make_tuple(s_q, s_k, s_kpad); + return std::make_tuple(s_q, s_k, s_qpad, s_kpad); } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 980dfb06ae..668fab3fd3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -313,7 +313,10 @@ struct FmhaBwdDQDKDVKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; - const int32_t* seqlen_k_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* seqlen_k_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional + const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std::conditional_t; @@ -520,7 +523,10 @@ struct FmhaBwdDQDKDVKernel void* dq_acc_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, + const void* cu_seqlen_q_ptr, + const void* cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -594,7 +600,10 @@ struct FmhaBwdDQDKDVKernel {}, // placeholder for deterministic reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(cu_seqlen_q_ptr), + reinterpret_cast(cu_seqlen_k_ptr)}; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -736,10 +745,29 @@ struct FmhaBwdDQDKDVKernel batch_offset_randval = query_start * kargs.stride_randval; } - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - if(kargs.seqlen_k_ptr != nullptr) + // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = + kargs.seqlen_q_ptr ? 
kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; + } + + // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k + if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + else if(kargs.seqlen_k_ptr != nullptr) { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } @@ -749,6 +777,12 @@ struct FmhaBwdDQDKDVKernel kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; } + // skip if logical lengths are zero + if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0) + { + return; + } + // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if constexpr(!kUseQrQtrDorPipeline) @@ -1246,6 +1280,8 @@ struct FmhaBwdOGradDotOKernel struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs { const int32_t* seqstart_q_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std:: @@ -1293,6 +1329,8 @@ struct FmhaBwdOGradDotOKernel void* d_ptr, float p_undrop, const void* seqstart_q_ptr, + const void* seqlen_q_ptr, + const void* cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, @@ -1311,7 +1349,9 @@ struct FmhaBwdOGradDotOKernel nhead_stride_do, nhead_stride_o, nhead_stride_d}, - reinterpret_cast(seqstart_q_ptr)}; + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(cu_seqlen_q_ptr)}; return kargs; } @@ -1355,9 +1395,23 @@ struct FmhaBwdOGradDotOKernel batch_offset_do = query_start * kargs.stride_do; batch_offset_d = query_start; - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqlen_q_ptr + ? 
static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + } + // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) @@ -1521,6 +1575,10 @@ struct FmhaBwdConvertQGradKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; // per-batch actual length [batch] + const int32_t* seqlen_k_ptr; // per-batch actual length [batch] + const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional + const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional }; using Kargs = std::conditional_t(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr)}; + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), + reinterpret_cast(seqlen_k_ptr), + reinterpret_cast(cu_seqlen_q_ptr), + reinterpret_cast(cu_seqlen_k_ptr)}; if constexpr(kIsDeterministic) { @@ -1632,13 +1698,41 @@ struct FmhaBwdConvertQGradKernel batch_offset_dq = query_start * kargs.stride_dq; batch_offset_dq_acc = query_start * kargs.stride_dq_acc; - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + const ck_tile::index_t physical_seqlen_q = + adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + kargs.seqlen_q = kargs.seqlen_q_ptr + ? static_cast(kargs.seqlen_q_ptr[i_batch]) + : physical_seqlen_q; + } + if constexpr(kIsDeterministic) { const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + const ck_tile::index_t physical_seqlen_k = + adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + + // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k + if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } + else + { + kargs.seqlen_k = + kargs.seqlen_k_ptr + ? static_cast(kargs.seqlen_k_ptr[i_batch]) + : physical_seqlen_k; + } } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index dafe99febe..f539c9d7e9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -296,8 +296,8 @@ struct FmhaFwdKernel // Optional cumulative sequence length pointers for batch mode // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. 
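Since these overrides carry the whole batch-mode tail-padding path, here is a host-side numeric sketch of what the kernel recovers from them (annotation only, not part of the patch; the effective lengths are invented):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    // Batch mode: tensors are allocated with one fixed seqlen_q (say 512), but
    // only the first {300, 512, 128} rows of each batch entry are valid.
    const std::vector<int32_t> eff_len_q = {300, 512, 128};

    std::vector<int32_t> cu_seqlen_q(eff_len_q.size() + 1, 0);
    std::partial_sum(eff_len_q.begin(), eff_len_q.end(), cu_seqlen_q.begin() + 1);
    // cu_seqlen_q == {0, 300, 812, 940}

    for(std::size_t b = 0; b < eff_len_q.size(); ++b)
    {
        // The kernel overrides its uniform seqlen_q karg with this difference;
        // memory offsets still use the fixed batch stride, so only the compute
        // bounds shrink and the tail padding is skipped.
        const int32_t seqlen_q = cu_seqlen_q[b + 1] - cu_seqlen_q[b];
        (void)seqlen_q; // == eff_len_q[b]
    }
    return 0;
}
```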
- const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD + const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD + const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -316,12 +316,12 @@ struct FmhaFwdKernel { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; + const int32_t* seqlen_q_ptr; const int32_t* seqlen_k_ptr; - // Optional cumulative padded sequence starts (including PAD tokens) - // Used solely to compute memory offsets when sequences are physically padded. - const int32_t* seqstart_padded_q_ptr = nullptr; - const int32_t* seqstart_padded_k_ptr = nullptr; + // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays + const int32_t* cu_seqlen_q_ptr = nullptr; + const int32_t* cu_seqlen_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -379,8 +379,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -471,8 +471,8 @@ struct FmhaFwdKernel kargs.init_logits_soft_cap(logits_soft_cap); } - kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; - kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -522,8 +522,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -570,7 +570,7 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_kv_ptr); + cu_seqlen_k_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -619,8 +619,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, - const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -667,7 +667,7 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_kv_ptr); + cu_seqlen_k_ptr); } template @@ -681,6 +681,7 @@ struct FmhaFwdKernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -711,8 +712,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -746,6 +747,7 @@ struct FmhaFwdKernel {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; if constexpr(BiasEnum == 
BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -804,8 +806,8 @@ struct FmhaFwdKernel kargs.min_seqlen_q = min_seqlen_q; } - kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); - kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); + kargs.cu_seqlen_q_ptr = reinterpret_cast(cu_seqlen_q_ptr); + kargs.cu_seqlen_k_ptr = reinterpret_cast(cu_seqlen_k_ptr); return kargs; } @@ -821,6 +823,7 @@ struct FmhaFwdKernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -850,8 +853,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -863,6 +866,7 @@ struct FmhaFwdKernel o_ptr, seqstart_q_ptr, seqstart_k_ptr, + seqlen_q_ptr, seqlen_k_ptr, hdim_q, hdim_v, @@ -892,8 +896,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - seqstart_padded_q_ptr, - seqstart_padded_k_ptr); + cu_seqlen_q_ptr, + cu_seqlen_k_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -908,6 +912,7 @@ struct FmhaFwdKernel void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, + const void* seqlen_q_ptr, const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -937,8 +942,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, - const void* seqstart_padded_q_ptr = nullptr, - const void* seqstart_padded_k_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -950,6 +955,7 @@ struct FmhaFwdKernel o_ptr, seqstart_q_ptr, seqstart_k_ptr, + seqlen_q_ptr, seqlen_k_ptr, hdim_q, hdim_v, @@ -979,8 +985,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), - seqstart_padded_q_ptr, - seqstart_padded_k_ptr); + cu_seqlen_q_ptr, + cu_seqlen_k_ptr); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1109,46 +1115,52 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // logical and physical (padded) starts - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? 
kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - // DRAM base offsets use physical padded starts - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + // Use seqstart_q_ptr and seqstart_k_ptr for physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + // DRAM base offsets use physical starts + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } if constexpr(kStoreLSE) { - // LSE stays indexed by unpadded starts - batch_offset_lse = query_start_unpadded; + // LSE follows the physical layout to stay consistent with other tensors + batch_offset_lse = query_start; } if constexpr(kHasDropout) { - batch_offset_randval = query_start_padded * kargs.stride_randval; + batch_offset_randval = query_start * kargs.stride_randval; } - batch_offset_o = query_start_padded * kargs.stride_o; + batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + } if constexpr(kSkipMinSeqlenQ) { @@ -1168,6 +1180,11 @@ struct FmhaFwdKernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; @@ -1201,10 +1218,10 @@ struct FmhaFwdKernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } @@ -1603,39 +1620,46 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - - const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr - ? kargs.seqstart_padded_q_ptr[i_batch] - : query_start_unpadded; - const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr - ? 
kargs.seqstart_padded_k_ptr[i_batch] - : key_start_unpadded; - - batch_offset_q = query_start_padded * kargs.stride_q; - batch_offset_k = key_start_padded * kargs.stride_k; + // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for + // physical starts + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start_padded * kargs.stride_v; + batch_offset_v = key_start * kargs.stride_v; } else { // col-major V: offset along seqlen dimension is scalar index - batch_offset_v = key_start_padded; + batch_offset_v = key_start; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start_padded * kargs.stride_bias; + batch_offset_bias = query_start * kargs.stride_bias; } - // LSE layout is [nhead, total_seqlen], index by unpadded start - batch_offset_lse = query_start_unpadded; - batch_offset_o = query_start_padded * kargs.stride_o; + // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O + batch_offset_lse = query_start; + batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode - kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + if(kargs.seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch]; + } + else if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + else + { + kargs.seqlen_q = + kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; + } // # of required blocks is different in each groups, terminate unnecessary blocks // earlier @@ -1648,6 +1672,11 @@ struct FmhaFwdKernel { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } + else if(kargs.cu_seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; + } else { kargs.seqlen_k = @@ -1677,10 +1706,10 @@ struct FmhaFwdKernel kargs.seqlen_q = kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; } - if(kargs.cu_seqlen_kv_ptr != nullptr) + if(kargs.cu_seqlen_k_ptr != nullptr) { kargs.seqlen_k = - kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch]; } } diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index ca7b7b6324..e769a79c08 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -5,6 +5,7 @@ endif() set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") + set(TEST_NAME "test_ck_tile_fmha") function(add_gtest_fwd test_group) diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp index 1279b98383..3eea02f888 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -77,6 +77,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param) nhead_k, {seqlen_q}, {seqlen_k}, + {-1}, + {-1}, hdim_q, hdim_v, i_perm, @@ -246,3 +248,741 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, Values(true) // deterministic )); TEST_P(Deterministic, DataTypeConfig) { fmha_bwd_test(GetParam()); } + +// ============================================================================ +// Q/KV Padding Tests - High Priority +// 
+
+// NOTE: PaddingTestParam is reconstructed here from the structured bindings used in the
+// TEST_P bodies below; the exact element types in the original source may differ.
+using PaddingDimsMask = // {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
+    std::tuple<int, int, int, int, int, std::string>;
+using PaddingTestParam = std::tuple<mode_enum,
+                                    std::tuple<int, int>,   // hdim_q, hdim_v
+                                    std::tuple<bool, bool>, // i_perm, o_perm
+                                    std::string,            // bias
+                                    bool,                   // use_dbias
+                                    float,                  // p_drop
+                                    std::tuple<uint64_t, uint64_t, bool>, // seed/offset/prefs
+                                    PaddingDimsMask,
+                                    bool>; // deterministic
+
+// 1. BasicQPadding: Test Q padding only (K/V have no padding)
+class BasicQPadding : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    BasicQPadding,
+    Combine(Values(mode_enum::group), // Only group mode supports padding
+            HDimValues,
+            Values(std::tuple{true, true}),  // perm
+            Values("n"),                     // no bias for basic test
+            Values(false),                   // use_dbias
+            Values(0.0f),                    // no dropout
+            Values(std::tuple{0, 0, false}), // seed/offset/prefs
+            ValuesIn([]() {
+                // Define test cases with Q padding: seqlen_q <= seqlen_qpad
+                // Format: {batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str}
+                // Note: seqlen_qpad is derived separately in the test body
+                std::vector<PaddingDimsMask> test_cases;
+
+                // Minimal padding: logical length just below the alignment boundary
+                test_cases.push_back(std::tuple{2, 2, 2, 127, 128, "0"}); // Q: 127->128
+                test_cases.push_back(std::tuple{3, 4, 2, 250, 256, "0"}); // Q: 250->256
+
+                // Moderate padding (~7-10%)
+                test_cases.push_back(std::tuple{2, 2, 1, 180, 256, "0"}); // Q: 180->192
+                test_cases.push_back(std::tuple{3, 3, 3, 350, 512, "1"}); // Q: 350->384, causal
+
+                // Larger padding, plus an already-aligned case
+                test_cases.push_back(std::tuple{2, 4, 2, 128, 256, "0"}); // Q: 128->128 (no padding)
+                test_cases.push_back(std::tuple{2, 2, 2, 200, 512, "2"}); // Q: 200->256, causal
+
+                return test_cases;
+            }()),
+            Values(false) // deterministic
+            ));
+
+TEST_P(BasicQPadding, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    // Set up Q padding: physical length larger than (or equal to) logical
+    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
+    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
+
+    // Calculate physical Q length (padded)
+    ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64; // Round up to multiple of 64
+    if(seqlen_q > 256)
+        seqlen_qpad = ((seqlen_q + 127) / 128) * 128; // Larger alignment for longer sequences
+
+    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
+    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_k); // No K padding
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0, // scale
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for Q padding with hdim_q=" << hdim_q;
+    ASSERT_EQ(result, bwd_result::success);
+}
+
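All padded lengths in these tests come from the same integer round-up idiom; a distilled sketch (the helper name is illustrative, the checked values are the ones quoted in the case comments above):

// round_up(n, a): smallest multiple of a that is >= n
constexpr int round_up(int n, int a) { return ((n + a - 1) / a) * a; }

static_assert(round_up(127, 64) == 128);  // 127 -> 128
static_assert(round_up(128, 64) == 128);  // already aligned: no padding added
static_assert(round_up(200, 64) == 256);  // 200 -> 256
static_assert(round_up(350, 128) == 384); // lengths > 256 use the coarser 128 alignment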
+// 2. BasicKVPadding: Test K/V padding only (Q has no padding)
+class BasicKVPadding : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    BasicKVPadding,
+    Combine(Values(mode_enum::group),
+            HDimValues,
+            Values(std::tuple{true, true}),
+            Values("n"),
+            Values(false),
+            Values(0.0f),
+            Values(std::tuple{0, 0, false}),
+            ValuesIn([]() {
+                std::vector<PaddingDimsMask> test_cases;
+
+                // Minimal K/V padding
+                test_cases.push_back(std::tuple{2, 2, 2, 128, 127, "0"}); // K: 127->128
+                test_cases.push_back(std::tuple{3, 4, 2, 256, 250, "0"}); // K: 250->256
+
+                // Moderate K/V padding
+                test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "0"}); // K: 180->192
+                test_cases.push_back(std::tuple{3, 3, 3, 512, 350, "1"}); // K: 350->384
+
+                // Larger K/V padding, plus an already-aligned case
+                test_cases.push_back(std::tuple{2, 4, 2, 256, 128, "0"}); // K: 128->128 (no padding)
+                test_cases.push_back(std::tuple{2, 2, 2, 512, 200, "2"}); // K: 200->256
+
+                return test_cases;
+            }()),
+            Values(false)));
+
+TEST_P(BasicKVPadding, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
+    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
+
+    // No Q padding
+    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_q);
+
+    // Set up K/V padding
+    ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
+    if(seqlen_k > 256)
+        seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
+    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for K/V padding with hdim_q=" << hdim_q;
+    ASSERT_EQ(result, bwd_result::success);
+}
+
+// 3. QKVPadding: Test both Q and K/V padding simultaneously
+class QKVPadding : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    QKVPadding,
+    Combine(Values(mode_enum::group),
+            HDimValues,
+            Values(std::tuple{true, true}),
+            Values("n"),
+            Values(false),
+            Values(0.0f),
+            Values(std::tuple{0, 0, false}),
+            ValuesIn([]() {
+                std::vector<PaddingDimsMask> test_cases;
+
+                // Both Q and K with minimal padding
+                test_cases.push_back(std::tuple{2, 2, 2, 120, 125, "0"}); // Q:120->128, K:125->128
+
+                // Both Q and K with moderate padding
+                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "0"}); // Q:180->192, K:200->256
+                test_cases.push_back(std::tuple{3, 3, 3, 300, 350, "1"}); // Q:300->384, K:350->384
+
+                // Larger padding, plus an already-aligned Q
+                test_cases.push_back(std::tuple{2, 2, 1, 150, 180, "0"}); // Q:150->192, K:180->192
+                test_cases.push_back(
+                    std::tuple{2, 4, 2, 256, 300, "2"}); // Q:256->256 (no padding), K:300->384
+
+                // Asymmetric lengths: short Q, longer K
+                test_cases.push_back(std::tuple{2, 2, 2, 100, 200, "0"}); // Q:100->128, K:200->256
+
+                // Asymmetric lengths: longer Q, short K
+                test_cases.push_back(std::tuple{2, 3, 1, 200, 100, "0"}); // Q:200->256, K:100->128
+
+                return test_cases;
+            }()),
+            Values(false)));
+
+TEST_P(QKVPadding, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
+    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
+
+    // Set up both Q and K/V padding
+    ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64;
+    if(seqlen_q > 256)
+        seqlen_qpad = ((seqlen_q + 127) / 128) * 128;
+
+    ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
+    if(seqlen_k > 256)
+        seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
+
+    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
+    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for Q+K/V padding with hdim_q=" << hdim_q;
+    ASSERT_EQ(result, bwd_result::success);
+}
+
+// 4. ZeroLengthPadding: Test zero-length sequences with padding
+class ZeroLengthPadding : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
+                         ZeroLengthPadding,
+                         Combine(Values(mode_enum::group),
+                                 Values(std::tuple{64, -1},
+                                        std::tuple{128, -1}), // Limited hdim for edge cases
+                                 Values(std::tuple{true, true}),
+                                 Values("n"),
+                                 Values(false),
+                                 Values(0.0f),
+                                 Values(std::tuple{0, 0, false}),
+                                 Values(
+                                     // Test case 1: zero-length Q (the loop below empties
+                                     // batches 0 and 1)
+                                     std::tuple{3, 2, 2, 0, 128, "0"},
+                                     // Test case 2: varied non-zero lengths (control case)
+                                     std::tuple{3, 2, 1, 100, 128, "0"},
+                                     // Test case 3: varied non-zero lengths (control case)
+                                     std::tuple{3, 3, 3, 150, 200, "0"},
+                                     // Test case 4: zero-length K in the first batch
+                                     std::tuple{3, 2, 2, 128, 0, "0"},
+                                     // Test case 5: varied lengths with padding, four batches
+                                     std::tuple{4, 2, 2, 80, 100, "0"}),
+                                 Values(false)));
+
+TEST_P(ZeroLengthPadding, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    // Create varied sequence lengths, injecting zero-length sequences whenever the base
+    // seqlen_q or seqlen_k is 0
+    std::vector<ck_tile::index_t> seqlen_qs;
+    std::vector<ck_tile::index_t> seqlen_ks;
+    std::vector<ck_tile::index_t> seqlen_qpads;
+    std::vector<ck_tile::index_t> seqlen_kpads;
+
+    for(int b = 0; b < batch; ++b)
+    {
+        ck_tile::index_t q_len, k_len;
+
+        if(seqlen_q == 0 && b == 1) // zero-Q base: the middle batch is empty as well
+        {
+            q_len = 0;
+            k_len = seqlen_k;
+        }
+        else if(seqlen_k == 0 && b == 0) // zero-K base: first batch has empty K
+        {
+            q_len = seqlen_q;
+            k_len = 0;
+        }
+        else
+        {
+            // Varied lengths (a zero-Q base also empties batch 0)
+            q_len = (b == 0 && seqlen_q == 0) ? 0 : (seqlen_q + b * 10);
+            k_len = seqlen_k + b * 15;
+        }
+
+        seqlen_qs.push_back(q_len);
+        seqlen_ks.push_back(k_len);
+
+        // Add padding for non-zero lengths
+        ck_tile::index_t qpad = (q_len == 0) ? 0 : ((q_len + 63) / 64) * 64;
+        ck_tile::index_t kpad = (k_len == 0) ? 0 : ((k_len + 63) / 64) * 64;
+
+        seqlen_qpads.push_back(qpad);
+        seqlen_kpads.push_back(kpad);
+    }
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for zero-length padding";
+    ASSERT_EQ(result, bwd_result::success);
+}
+
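Once zero-length sequences enter the mix, the cumulative arrays the kernel consumes simply repeat an entry; a self-contained sketch using the lengths produced by case 1 above (array names follow the fmha_bwd_args comments, the computation itself is illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // logical Q lengths from case 1; padded lengths as derived in the test body
    const std::vector<int> logical = {0, 0, 20};
    const std::vector<int> padded  = {0, 0, 64};

    std::vector<int> seqstart_q{0}, cu_seqlen_q{0};
    for(std::size_t i = 0; i < logical.size(); ++i)
    {
        seqstart_q.push_back(seqstart_q.back() + padded[i]);    // physical prefix sum
        cu_seqlen_q.push_back(cu_seqlen_q.back() + logical[i]); // logical prefix sum
    }
    // seqstart_q == {0, 0, 0, 64} and cu_seqlen_q == {0, 0, 0, 20}: an empty batch
    // repeats the previous cumulative entry, so the kernel derives seqlen_q == 0 for it
    // and its work-groups exit early
    assert(seqstart_q[1] - seqstart_q[0] == 0);
    assert(cu_seqlen_q[2] - cu_seqlen_q[1] == 0);
    return 0;
}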
+// ============================================================================
+// Q/KV Padding Tests - Medium Priority
+// ============================================================================
+
+// 5. VariedPaddingRatios: Test different padding ratios (waste ratios)
+class VariedPaddingRatios : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    VariedPaddingRatios,
+    Combine(Values(mode_enum::group),
+            HDimValues,
+            Values(std::tuple{true, true}),
+            Values("n"),
+            Values(false),
+            Values(0.0f),
+            Values(std::tuple{0, 0, false}),
+            ValuesIn([]() {
+                std::vector<PaddingDimsMask> test_cases;
+
+                // Minimal waste: ~1-5% padding (logical close to physical)
+                test_cases.push_back(
+                    std::tuple{2, 2, 2, 127, 127, "0"}); // Q:127->128 (~0.8%), K:127->128
+                test_cases.push_back(
+                    std::tuple{2, 4, 2, 252, 250, "0"}); // Q:252->256 (~1.6%), K:250->256
+                test_cases.push_back(std::tuple{2, 2, 1, 509, 505, "1"}); // Q:509->512, K:505->512
+
+                // Low waste: ~10-20% padding
+                test_cases.push_back(
+                    std::tuple{2, 3, 3, 220, 210, "0"}); // Q:220->256 (~16%), K:210->256
+                test_cases.push_back(
+                    std::tuple{3, 2, 2, 440, 420, "0"}); // Q:440->512 (~16%), K:420->512
+                test_cases.push_back(std::tuple{2, 4, 2, 350, 340, "1"}); // Q:350->384, K:340->384
+
+                // Medium waste (see per-case ratios)
+                test_cases.push_back(
+                    std::tuple{2, 2, 2, 180, 170, "0"}); // Q:180->256 (~42%), K:170->256
+                test_cases.push_back(
+                    std::tuple{2, 3, 1, 320, 310, "0"}); // Q:320->384 (~20%), K:310->384
+                test_cases.push_back(std::tuple{3, 2, 2, 350, 340, "2"}); // Q:350->384, K:340->384
+
+                // High waste: up to ~100% padding
+                test_cases.push_back(
+                    std::tuple{2, 2, 2, 130, 130, "0"}); // Q:130->256 (~97%), K:130->256
+                test_cases.push_back(
+                    std::tuple{2, 4, 2, 260, 260, "0"}); // Q:260->384 (~48%), K:260->384
+                test_cases.push_back(
+                    std::tuple{2, 2, 1, 200, 200, "1"}); // Q:200->256 (~28%), K:200->256
+
+                // Extreme waste: very small logical vs large physical
+                test_cases.push_back(std::tuple{2, 2, 2, 65, 70, "0"});  // Q:65->128, K:70->128
+                test_cases.push_back(std::tuple{2, 3, 3, 100, 90, "0"}); // Q:100->128, K:90->128
+
+                return test_cases;
+            }()),
+            Values(false)));
+
+TEST_P(VariedPaddingRatios, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
+    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
+
+    // Calculate padding based on common alignment strategies
+    auto calc_pad = [](ck_tile::index_t len) -> ck_tile::index_t {
+        if(len <= 64)
+            return 64;
+        else if(len <= 128)
+            return 128;
+        else if(len <= 256)
+            return 256;
+        else if(len <= 384)
+            return 384;
+        else if(len <= 512)
+            return 512;
+        else
+            return ((len + 127) / 128) * 128;
+    };
+
+    std::vector<ck_tile::index_t> seqlen_qpads(batch, calc_pad(seqlen_q));
+    std::vector<ck_tile::index_t> seqlen_kpads(batch, calc_pad(seqlen_k));
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for varied padding ratios";
+    ASSERT_EQ(result, bwd_result::success);
+}
+
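The waste percentages quoted in the case comments above are measured relative to the logical length; a one-line sketch of the convention (the helper name is illustrative):

// waste = extra padded elements as a fraction of the logical length
constexpr double waste_pct(int logical, int padded)
{
    return 100.0 * (padded - logical) / logical;
}
// waste_pct(127, 128) ~ 0.8    waste_pct(220, 256) ~ 16.4
// waste_pct(260, 384) ~ 47.7   waste_pct(130, 256) ~ 96.9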
+// 6. PaddingWithMask: Test padding combined with various mask types
+class PaddingWithMask : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(
+    TestCkTileFmhaBwd,
+    PaddingWithMask,
+    Combine(Values(mode_enum::group),
+            Values(std::tuple{64, -1}, std::tuple{128, -1}), // Focus on common sizes
+            Values(std::tuple{true, true}),
+            Values("n"),
+            Values(false),
+            Values(0.0f),
+            Values(std::tuple{0, 0, false}),
+            ValuesIn([]() {
+                // Mask strings follow the example's convention: "0" = none, "1"/"t" =
+                // top-left causal, "2"/"b" = bottom-right causal, and "t:l,r" / "b:l,r" =
+                // sliding window with left/right window sizes
+                std::vector<PaddingDimsMask> test_cases;
+
+                // No mask with padding (baseline)
+                test_cases.push_back(std::tuple{2, 2, 2, 200, 180, "0"});
+
+                // Causal mask (top-left) with Q padding
+                test_cases.push_back(std::tuple{2, 2, 2, 200, 256, "1"}); // Q padded, K exact
+                test_cases.push_back(std::tuple{2, 4, 2, 180, 200, "t"}); // Both padded, causal
+
+                // Causal mask (bottom-right) with K/V padding
+                test_cases.push_back(std::tuple{2, 2, 1, 256, 180, "2"}); // K padded, Q exact
+                test_cases.push_back(
+                    std::tuple{2, 3, 3, 200, 180, "b"}); // Both padded, bottom-right
+
+                // Sliding window attention with padding
+                test_cases.push_back(std::tuple{2, 2, 2, 200, 190, "t:64,32"});  // SWA + padding
+                test_cases.push_back(std::tuple{2, 4, 2, 180, 170, "b:32,64"});  // SWA + padding
+                test_cases.push_back(std::tuple{3, 2, 1, 220, 210, "t:100,50"}); // Larger window
+
+                // Sliding window with asymmetric padding
+                test_cases.push_back(std::tuple{2, 2, 2, 150, 250, "t:80,40"}); // Q more padded
+                test_cases.push_back(std::tuple{2, 3, 3, 250, 150, "b:50,70"}); // K more padded
+
+                // Mixed scenarios
+                test_cases.push_back(std::tuple{2, 4, 2, 190, 185, "t:50,50"}); // Symmetric window
+                test_cases.push_back(std::tuple{3, 2, 2, 300, 280, "1"});       // Multi-batch causal
+
+                return test_cases;
+            }()),
+            Values(false)));
+
+TEST_P(PaddingWithMask, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    std::vector<ck_tile::index_t> seqlen_qs(batch, seqlen_q);
+    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k);
+
+    // Apply padding
+    ck_tile::index_t seqlen_qpad = ((seqlen_q + 63) / 64) * 64;
+    ck_tile::index_t seqlen_kpad = ((seqlen_k + 63) / 64) * 64;
+
+    if(seqlen_q > 256)
+        seqlen_qpad = ((seqlen_q + 127) / 128) * 128;
+    if(seqlen_k > 256)
+        seqlen_kpad = ((seqlen_k + 127) / 128) * 128;
+
+    std::vector<ck_tile::index_t> seqlen_qpads(batch, seqlen_qpad);
+    std::vector<ck_tile::index_t> seqlen_kpads(batch, seqlen_kpad);
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for padding with mask";
+    ASSERT_EQ(result, bwd_result::success);
+}
+
+// 7. MultiBatchPadding: Test multiple batches with different padding configurations
+class MultiBatchPadding : public TestWithParam<PaddingTestParam>
+{
+};
+
+INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
+                         MultiBatchPadding,
+                         Combine(Values(mode_enum::group),
+                                 Values(std::tuple{64, -1}, std::tuple{128, -1}),
+                                 Values(std::tuple{true, true}),
+                                 Values("n"),
+                                 Values(false),
+                                 Values(0.0f),
+                                 Values(std::tuple{0, 0, false}),
+                                 Values(
+                                     // 3 batches with varied Q/K lengths and padding
+                                     std::tuple{3, 2, 2, 150, 200, "0"},
+                                     // 4 batches with different patterns
+                                     std::tuple{4, 3, 3, 180, 220, "0"},
+                                     // 5 batches with mixed scenarios
+                                     std::tuple{5, 2, 1, 120, 160, "1"},
+                                     // 3 batches with causal mask
+                                     std::tuple{3, 4, 2, 200, 180, "t"},
+                                     // 4 batches with sliding window
+                                     std::tuple{4, 2, 2, 160, 140, "t:50,30"}),
+                                 Values(false)));
+
+TEST_P(MultiBatchPadding, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, base_seqlen_q, base_seqlen_k, mask_str] = dims_mask;
+
+    // Create varied sequence lengths for each batch
+    std::vector<ck_tile::index_t> seqlen_qs;
+    std::vector<ck_tile::index_t> seqlen_ks;
+    std::vector<ck_tile::index_t> seqlen_qpads;
+    std::vector<ck_tile::index_t> seqlen_kpads;
+
+    for(int b = 0; b < batch; ++b)
+    {
+        // Generate varied lengths across batches: decreasing, increasing, or mixed
+        ck_tile::index_t q_len, k_len;
+
+        switch(b % 3)
+        {
+        case 0: // Decreasing
+            q_len = base_seqlen_q - b * 20;
+            k_len = base_seqlen_k - b * 25;
+            break;
+        case 1: // Increasing
+            q_len = base_seqlen_q + b * 15;
+            k_len = base_seqlen_k + b * 20;
+            break;
+        case 2: // Mixed
+            q_len = base_seqlen_q + (b % 2 == 0 ? 10 : -10) * b;
+            k_len = base_seqlen_k + (b % 2 == 0 ? -15 : 15) * b;
+            break;
+        }
+
+        // Ensure positive lengths
+        q_len = std::max<ck_tile::index_t>(64, q_len);
+        k_len = std::max<ck_tile::index_t>(64, k_len);
+
+        seqlen_qs.push_back(q_len);
+        seqlen_ks.push_back(k_len);
+
+        // Use a different padding strategy per batch
+        ck_tile::index_t qpad, kpad;
+
+        if(b % 4 == 0)
+        {
+            // Tight padding (minimal waste)
+            qpad = ((q_len + 31) / 32) * 32;
+            kpad = ((k_len + 31) / 32) * 32;
+        }
+        else if(b % 4 == 1)
+        {
+            // Medium padding
+            qpad = ((q_len + 63) / 64) * 64;
+            kpad = ((k_len + 63) / 64) * 64;
+        }
+        else if(b % 4 == 2)
+        {
+            // Loose padding
+            qpad = ((q_len + 127) / 128) * 128;
+            kpad = ((k_len + 127) / 128) * 128;
+        }
+        else
+        {
+            // Mixed: Q tight, K loose
+            qpad = ((q_len + 31) / 32) * 32;
+            kpad = ((k_len + 127) / 128) * 128;
+        }
+
+        seqlen_qpads.push_back(qpad);
+        seqlen_kpads.push_back(kpad);
+    }
+
+    auto result = fmha_bwd_run(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        seqlen_qs,
+        seqlen_ks,
+        seqlen_qpads,
+        seqlen_kpads,
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0,
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for multi-batch padding";
+    ASSERT_EQ(result, bwd_result::success);
+}
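For manual reproduction, the same padded group-mode shapes can be fed to the example binary. A sketch, assuming the example's usual binary name and mode/batch flags; only -s/-s_k and the new -s_qpad/-s_kpad options are taken from this patch:

# two batches, group mode: logical Q {180, 200} padded to {192, 256},
# logical K {200, 250} padded to {256, 256}
./tile_example_fmha_bwd -mode=1 -b=2 \
    -s=180,200 -s_qpad=192,256 \
    -s_k=200,250 -s_kpad=256,256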