Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,14 @@ def skcheck(self) -> str:
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag in ["qr", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
elif self.pipeline_tag == "qr_async_trload":
if self.skpad == "t":
return "true"
Expand Down
12 changes: 12 additions & 0 deletions example/ck_tile/01_fmha/example_fmha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ auto create_args(int argc, char* argv[])
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
Expand Down Expand Up @@ -96,7 +104,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
Expand Down Expand Up @@ -130,6 +140,8 @@ auto run(const ck_tile::ArgParser& arg_parser)
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
Expand Down
57 changes: 54 additions & 3 deletions example/ck_tile/01_fmha/fmha_bwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,51 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;

// Usage notes for sequence length pointer parameters:
//
// [Note: Define "Group mode" vs "Batch mode" here if possible, e.g., "Group mode handles
// MQA/GQA..."]
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
Expand Down Expand Up @@ -203,7 +245,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
Expand Down Expand Up @@ -315,6 +360,8 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
Expand Down Expand Up @@ -356,6 +403,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
Expand Down
Loading