Support page kvcache in AMD ROCm #1198

Merged · 52 commits · Sep 16, 2024

Commits (52)
b5a4204
Integrate ck branch of ck_tile/fa_bwd_opt
rocking5566 Jul 22, 2024
1e3416d
Assume dq and q share the same stride
rocking5566 Jul 25, 2024
bc78de1
update ck
rocking5566 Jul 25, 2024
46d1cff
Integrate more stride of dq_acc
rocking5566 Jul 28, 2024
eac0e38
Revert fwd dropout
rocking5566 Jul 28, 2024
b14f245
Fix parameter order
rocking5566 Jul 28, 2024
3180632
Integrate ck with more stride
rocking5566 Jul 30, 2024
c7ac11f
update the limit of hdim of bwd
rocking5566 Jul 30, 2024
6af216c
Check argument
rocking5566 Aug 2, 2024
05b657e
Add test_flash_attn_causal
rocking5566 Aug 2, 2024
dbe28cb
Support unpad lse
rocking5566 Aug 7, 2024
7b712b2
Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, t…
rocking5566 Aug 7, 2024
5346b5b
Fix stride and Kn0
rocking5566 Aug 7, 2024
381bbdd
Fix CK sync issue
rocking5566 Aug 12, 2024
6ac697c
Fix typo
rocking5566 Aug 12, 2024
bc5dd34
Merge commit '3669b25206d5938e3cc74a5f7860e31c38af8204' into ck_impro…
rocking5566 Aug 13, 2024
909d66f
Update CK for changing of fmha_fwd_args
rocking5566 Aug 13, 2024
6928524
Add kvcache tmp
rocking5566 Aug 17, 2024
23000f7
Add kvcache
rocking5566 Aug 19, 2024
2c80b86
Fix comment
rocking5566 Aug 19, 2024
d2ed413
Sync behavior with ck
rocking5566 Aug 19, 2024
c0c2f8f
Update CK to develop
rocking5566 Aug 19, 2024
e037142
remove large test case
rocking5566 Aug 19, 2024
d38c59b
Merge pull request #70 from ROCm/ck_bwd_opt
rocking5566 Aug 19, 2024
a37da96
Merge remote-tracking branch 'origin/ck_improve_v0.1.1' into ck_tile/…
rocking5566 Aug 20, 2024
d8de7a6
Add kvcache test
rocking5566 Aug 20, 2024
a84d12d
Fix page_block_size in arg
rocking5566 Aug 20, 2024
5b4546c
Minor fix
rocking5566 Aug 20, 2024
d6aac9e
Fix stride error
rocking5566 Aug 20, 2024
ae24800
Update seqlen of kvcache before splitkv
rocking5566 Aug 21, 2024
e2d3f5b
Fix compile error
rocking5566 Aug 21, 2024
bb7a439
Fix bug of hdim is not 8x
rocking5566 Aug 21, 2024
7b18b87
Fit ck arg
rocking5566 Aug 23, 2024
94e054f
support adaptive num_splits
rocking5566 Aug 26, 2024
9316aa6
add more tests
rocking5566 Aug 26, 2024
7815e3b
Refine test tolerance
rocking5566 Aug 26, 2024
27095f2
update CK
rocking5566 Aug 26, 2024
4a25f60
Move override_num_splits_if_necessary into cpp
rocking5566 Aug 26, 2024
22eee22
update ck
rocking5566 Aug 27, 2024
007ae03
Update ck
rocking5566 Aug 28, 2024
7259227
Merge pull request #74 from ROCm/ck_tile/kvcache
rocking5566 Aug 28, 2024
444ab9f
Merge branch 'Dao-AILab:main' into ck_improve_v0.1.1
rocking5566 Aug 28, 2024
4c0f9d2
Support different flag for different version of hip
rocking5566 Aug 28, 2024
89ac30b
remove coerce-illegal, because this is not required in FA
rocking5566 Aug 29, 2024
a381df5
Update ck to fix scratch memory
rocking5566 Aug 30, 2024
6635e24
Add coerce-illegal in some version
rocking5566 Aug 30, 2024
cc01a17
Merge pull request #77 from ROCm/ck_tile/rocm6.2-flag
rocking5566 Aug 30, 2024
1cb8f8d
Add compile flag for rtn rounding
rocking5566 Aug 30, 2024
ba86d74
remove redundant init
rocking5566 Aug 31, 2024
8c4f9cd
Using env var to switch rounding mode
rocking5566 Sep 2, 2024
ece97c7
update ck
rocking5566 Sep 2, 2024
b40c1a0
Merge pull request #78 from ROCm/ck_tile/bf16_rtn
rocking5566 Sep 3, 2024
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 386 files
41 changes: 32 additions & 9 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -20,16 +20,16 @@ mha_fwd(at::Tensor &q,
c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@@ -89,11 +89,34 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}
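
Note: the new `fwd_kvcache` binding is presumably what the Python-level `flash_attn_with_kvcache` wrapper dispatches to on ROCm, so the paged KV cache (block_table) path added here is reachable from Python. A minimal decode-step sketch, assuming the usual flash_attn Python API; shapes follow the comments in the declaration above, and the sizes (including the 256-entry page_block_size) are illustrative, not values from this PR:

```python
# Sketch of a single decode step through the paged KV-cache path (illustrative
# shapes; keyword names follow the upstream flash_attn Python wrapper).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, nheads_k, hdim = 2, 8, 8, 128
page_block_size, blocks_per_seq = 256, 4               # assumed paging layout
num_blocks = batch * blocks_per_seq

q      = torch.randn(batch, 1, nheads, hdim, dtype=torch.float16, device="cuda")
kcache = torch.randn(num_blocks, page_block_size, nheads_k, hdim,
                     dtype=torch.float16, device="cuda")
vcache = torch.randn_like(kcache)
knew   = torch.randn(batch, 1, nheads_k, hdim, dtype=torch.float16, device="cuda")
vnew   = torch.randn_like(knew)

# block_table: batch_size x max_num_blocks_per_seq; cache_seqlens: batch_size
block_table   = torch.arange(num_blocks, dtype=torch.int32,
                             device="cuda").reshape(batch, blocks_per_seq)
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, kcache, vcache, k=knew, v=vnew,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
print(out.shape)  # (batch, 1, nheads, hdim)
```
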
34 changes: 34 additions & 0 deletions csrc/flash_attn_ck/flash_common.cpp
@@ -0,0 +1,34 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"

namespace flash {
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
return num_splits;

hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
return num_splits;

// TODO - tile size should match the TileFmhaShape, hardcode for now
const int kM0 = 128;
const int kN1 = hdim_v;

const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

if(num_splits < 1 && p_drop == 0.0f)
return num_splits_heuristic_ck(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);

return num_splits;
}

} // namespace flash
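
The override is deliberately conservative: it only replaces `num_splits` when the caller passed a value below 1 and dropout is disabled, and with `kN1 = hdim_v` the computed `num_n_blocks` is always 1. A small sketch of the values handed to the heuristic for a single-token decode, assuming the hardcoded kM0 = 128 tile; the 80-CU device property is a made-up figure:

```python
# Values that reach num_splits_heuristic_ck for one decode step (sketch only).
from math import ceil

batch, nhead, max_seqlen_q, hdim_v = 2, 8, 1, 128
kM0, kN1 = 128, hdim_v                       # hardcoded tile sizes (see TODO above)

num_m_blocks = ceil(max_seqlen_q / kM0)      # 1
num_n_blocks = ceil(hdim_v / kN1)            # 1, because kN1 == hdim_v
work_items   = batch * nhead * num_m_blocks  # 16
wave_slots   = 80 * 2                        # props.multiProcessorCount * 2

print(work_items, wave_slots, num_n_blocks)  # 16 160 1
```
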
40 changes: 39 additions & 1 deletion csrc/flash_attn_ck/flash_common.hpp
@@ -24,7 +24,7 @@
namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
@@ -35,4 +35,42 @@ static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
}
}

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
// If we have enough to almost fill the SMs, then just use 1 split
if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) {
efficiency.push_back(0.f);
} else {
float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) { continue; }
if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

} // namespace flash
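
`num_splits_heuristic_ck` mirrors the CUDA-side split-KV heuristic: among split counts that actually change the number of KV blocks per split, it picks the smallest one whose wave occupancy is within 85% of the best achievable. A standalone sketch of the same logic with made-up inputs (not a trace of a real kernel launch):

```python
# Python transcription of num_splits_heuristic_ck above (sketch).
from math import ceil

def num_splits_heuristic(batch_nheads_mblocks, num_sms, num_n_blocks, max_splits):
    if batch_nheads_mblocks >= 0.8 * num_sms:   # enough work to nearly fill the GPU
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)
    ceildiv = lambda a, b: (a + b - 1) // b
    # A split count is eligible only if it changes the blocks-per-split count.
    eligible = lambda s: s == 1 or ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1)
    eff = []
    for s in range(1, max_splits + 1):
        if not eligible(s):
            eff.append(0.0)
        else:
            n_waves = batch_nheads_mblocks * s / num_sms
            eff.append(n_waves / ceil(n_waves))
    best = max(eff, default=0.0)
    for s in range(1, max_splits + 1):
        if eligible(s) and eff[s - 1] >= 0.85 * best:
            return s
    return 1

# Made-up inputs: 16 work items (batch*heads*m-blocks), 160 wave slots, 8 KV blocks.
print(num_splits_heuristic(16, 160, 8, 128))   # -> 8
```
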
103 changes: 71 additions & 32 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -11,7 +11,8 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_dropout,
bool enable_alibi)
bool enable_alibi,
bool deterministic)
{
return fmha_bwd_traits{head_size,
head_size,
@@ -20,7 +20,9 @@ mask.type,
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
has_dropout};
has_dropout,
false, // s_randval
deterministic};
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -39,6 +42,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor dq_acc,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
@@ -49,41 +53,57 @@ uint64_t drop_offset)
uint64_t drop_offset)
{
// q: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t nhead_stride_q = q.stride(2);

// k: (batch_size, seqlen_k, nheads_k, hdim)
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t nhead_stride_k = k.stride(2);

// v: (batch_size, seqlen_k, nheads_k, hdim)
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t nhead_stride_v = v.stride(2);

// o: (batch_size, seqlen_q, nheads, hdim)
// dq: (batch_size, seqlen_q, nheads, hdim)
// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
// do: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t stride_o = out.stride(1);
ck_tile::index_t nhead_stride_o = out.stride(2);

// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (batch_size, nheads, seqlen_q)
// d: (batch_size, nheads, seqlen_q)
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

ck_tile::index_t stride_q = q.stride(1);
ck_tile::index_t stride_k = k.stride(1);
ck_tile::index_t stride_v = v.stride(1);
ck_tile::index_t stride_o = out.stride(1);
// do: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t stride_do = dout.stride(1);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t stride_dv = dv.stride(1);

ck_tile::index_t nhead_stride_q = q.stride(2);
ck_tile::index_t nhead_stride_k = k.stride(2);
ck_tile::index_t nhead_stride_v = v.stride(2);
ck_tile::index_t nhead_stride_o = out.stride(2);
ck_tile::index_t nhead_stride_do = dout.stride(2);
ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

ck_tile::index_t batch_stride_q = q.stride(0);
ck_tile::index_t batch_stride_k = k.stride(0);
ck_tile::index_t batch_stride_v = v.stride(0);
ck_tile::index_t batch_stride_o = out.stride(0);
ck_tile::index_t batch_stride_do = dout.stride(0);
ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
// d: (batch_size, nheads, seqlen_q)
// CK assume d share the same stride with lse

// dq: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_dq = dq.stride(0);
ck_tile::index_t stride_dq = dq.stride(1);
ck_tile::index_t nhead_stride_dq = dq.stride(2);

// dk_expanded: (batch_size, seqlen_k, nheads, hdim)
ck_tile::index_t batch_stride_dk = dk.stride(0);
ck_tile::index_t stride_dk = dk.stride(1);
ck_tile::index_t nhead_stride_dk = dk.stride(2);

// dv_expanded: (batch_size, seqlen_k, nheads, hdim)
ck_tile::index_t batch_stride_dv = dv.stride(0);
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);

float p_undrop = 1.0 - p_dropout;

@@ -96,6 +116,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

@@ -112,6 +133,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_k_ptr
@@ -132,6 +154,8 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
stride_o,
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
@@ -143,6 +167,10 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
@@ -152,15 +180,17 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
false, // s_randval
{drop_seed, drop_offset}};
}

@@ -224,7 +254,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -296,7 +326,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
// TODO - CK does not support dq_accum
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -326,10 +364,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num

if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
dq.zero_(); // ck use atomic operation on dq

auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);

auto args =
get_ck_fmha_bwd_args(
@@ -347,6 +384,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
out,
softmax_lse,
dout_padded,
dq_accum,
softmax_d,
dq,
dk_expanded,
@@ -356,7 +394,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
drop_seed,
drop_offset);

fmha_bwd(traits, args, stream_config);
float t = fmha_bwd(traits, args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
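
One more note on the backward changes above: `dq_acc` is a new float32 accumulator for dQ. The default (non-deterministic) path keeps a single split slice that CK updates atomically (the code comments note the atomic updates on dq), while `deterministic=true` allocates one slice per KV tile of length kN0, presumably so partial dQ contributions can be reduced in a fixed order. A rough sketch of the resulting shape, under the same kN0 rule as in mha_bwd:

```python
# Sketch of the dq_accum allocation rule in mha_bwd above (shape only).
import math

def dq_accum_shape(batch, seqlen_q, seqlen_k, nheads, head_size_8x, deterministic):
    if not deterministic:
        nsplits = 1                                   # single slice
    else:
        kN0 = 128 if head_size_8x <= 128 else 64      # KV tile length along seqlen_k
        nsplits = math.ceil(seqlen_k / kN0)           # one slice per KV tile
    return (nsplits, batch, seqlen_q, nheads, head_size_8x)

# e.g. deterministic backward with seqlen_k = 1024 and head dim 128:
print(dq_accum_shape(2, 1024, 1024, 8, 128, True))    # (8, 2, 1024, 8, 128)
```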