Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
e90b991
[ROCm] manually pick up fwd native padding support from Meekail's PR
wangye805 Oct 16, 2025
9d02d52
Initial update
Micky774 Oct 16, 2025
81bac35
Updated stride
Micky774 Oct 16, 2025
54ee86a
Corrected typing in allocation portions
Micky774 Oct 16, 2025
47a7cab
Applied Ye's patch
Micky774 Oct 17, 2025
0e0064f
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
945ab5b
[ROCm] jax use runtime segment
wangye805 Oct 21, 2025
579b592
[ROCm] get runtime max_seqlen as well
wangye805 Oct 22, 2025
73247d9
[ROCm] support v2 bwd native padding
wangye805 Oct 22, 2025
7e1c3ef
Updated conversion to include bwd pass
Micky774 Oct 22, 2025
51090d3
Merge branch 'yewang12/te_aiter_native_padding_bwd' into zain/aiter-b…
Micky774 Oct 23, 2025
0e121ba
Added BWD BSHD-->THD conversion and minor logic refactor
Micky774 Oct 23, 2025
734692d
Corrected softmax lse bug
Micky774 Oct 23, 2025
5c24188
Updated logic flow and re-caclulation
Micky774 Oct 23, 2025
b59d466
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
97073fe
Merge branch 'zain/aiter-bwd-bshd-thd' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
f27a99f
Added env var guard
Micky774 Oct 28, 2025
d757aef
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
33c5912
Updated ptr variables and streamlined dispatch
Micky774 Oct 29, 2025
af57290
Added env guard
Micky774 Oct 29, 2025
bc8f4a7
Corrected bshd_to_thd conversion arguments
Micky774 Oct 29, 2025
b7f2cf8
Corrected logical flow
Micky774 Oct 30, 2025
3e48a02
Guarded memset and corrected allocation
Micky774 Nov 5, 2025
b1094c6
Remove V3 API check and guard memsets
Micky774 Nov 5, 2025
c3a0fce
PR comments
Micky774 Nov 6, 2025
9ab8df4
Updated documentation
Micky774 Nov 10, 2025
2adfb6e
PR review reconciliation
Micky774 Nov 10, 2025
bb3868d
Added explicit test
Micky774 Nov 12, 2025
52c8167
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 12, 2025
6206d58
Formatting for bwd debug
Micky774 Nov 13, 2025
0582851
Resolved error when using mixed formats e.g. sbhd_2bshd
Micky774 Nov 14, 2025
78716de
Updated guard on flash-attention forced support
Micky774 Nov 14, 2025
85bb6f6
Added check for SBHD_2BSHD
Micky774 Nov 14, 2025
a12105d
Added guard on dk/dv memset
Micky774 Nov 14, 2025
2edd3d4
Removed env var gating for dk/dv zero padding, formatting
Micky774 Nov 24, 2025
221f286
Added inline comment to test
Micky774 Nov 24, 2025
68d4faf
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 24, 2025
e84e385
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 25, 2025
1eb25ea
Corrected Softmax LSE buffer allocation
Micky774 Nov 25, 2025
6ecea1d
Correct Softmax LSE buffer memory allocation
Micky774 Dec 3, 2025
0abc5e4
Adjusted fwd pass softmax lse allocation
Micky774 Dec 3, 2025
ed64d0b
Adjusted bwd pass softmax conversion allocation
Micky774 Dec 3, 2025
46fe62b
Minor reversions
Micky774 Dec 3, 2025
b925c19
[ROCm] fix the aiter fwd v3 cu_seqlen/cu_seqlen_padded api issue
wangye805 Dec 15, 2025
f500863
Update README.rst to fix formatting
wangye805 Dec 15, 2025
6fb596a
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Dec 15, 2025
19f2046
Merge remote-tracking branch 'origin/dev' into zain/aiter-native-bshd…
wangye805 Dec 16, 2025
cc37cb2
[ROCm] update aiter commit with swa fix
wangye805 Jan 12, 2026
871cb4e
Merge remote-tracking branch 'origin/dev' into zain/aiter-native-bshd…
wangye805 Jan 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1167 files
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
for both the `FusedAttention` and `DotProductAttention` modules.

Certain settings can be enabled to potentially optimize workloads depending on the nature of the inputs and expected outputs:

* NVTE_CK_RUNTIME_NUM_SEGMENTS - by default 0, if set to 1 then the JAX integration will calculate the number of segments at runtime. Enabling this requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_RUNTIME_MAX_SEQLEN - by default 0, if set to 1 then the max sequence length will be calculated at runtime. This can result in speedups in cases where there are many zero-length sequences. Enabling this while using the JAX integration requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding.

AITER FA v3 Kernels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
Expand Down
23 changes: 23 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,29 @@ def test():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

# TODO: Enable config support in other backend(s) -- currently only the CK
# backend is capable of supporting it.
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_gqa_mla_thd():
"""
Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration
post-processing for BWD FA with native padding support.
"""
config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64)
qkv_layout = "thd_thd_thd"
dtype = torch.float16
_, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=True,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test requires the CK fused attention backend.")

test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_dot_product_mem_calc():
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ hipError_t ck_attn_fwd(
uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_varlen_fwd(
Expand All @@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
bool is_training,
float scaling_factor,
float dropout_probability,
Expand All @@ -82,6 +84,7 @@ hipError_t ck_attn_varlen_fwd(
uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_thd_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_bwd(
Expand Down Expand Up @@ -137,6 +140,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand Down
132 changes: 99 additions & 33 deletions transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@

namespace ck_fused_attn{

// TODO: unify with binary search in TE/common/fused_attn(rocm)/util
// no device std::upper_bound
// in an increasing array with given size len, search for the index that:
// array[index] <= target < array[index+1]
// guaranteed that target >=0 and target <= cu_seqlen[end-1]
__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}

// define dk_dv_reduce function only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce(
Expand Down Expand Up @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce(
// define dk_dv_reduce function in THD layout only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_expanded,
const DataType *dv_expanded,
uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded,
Expand All @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd(
uint64_t hdim_idx = threadIdx.x;

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd(
// When d_qk != d_v, we need to reduce dk and dv separately
template<typename DataType>
__global__ void dk_or_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_or_dv_expanded,
uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded,
DataType *dk_or_dv,
Expand All @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd(

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -323,7 +355,7 @@ void log_bwd_config(const char* func_name,
std::cout<<std::endl<<func_name<<std::endl;

// fmha_traits debug
std::cout<<"fmha_traits: "<<std::endl;
std::cout<<std::endl<<"fmha_traits: "<<std::endl;
std::cout<<"hdim_q: "<<fmha_args.hdim_q<<std::endl;
std::cout<<"hdim_v: "<<fmha_args.hdim_v<<std::endl;
std::cout<<"data_type: "<<data_type_str<<std::endl;
Expand All @@ -339,7 +371,7 @@ void log_bwd_config(const char* func_name,
std::cout<<"how_v3_bf16_cvt: "<<how_v3_bf16_cvt<<std::endl;

// fmha_args debug
std::cout<<"fmha_args: "<<std::endl;
std::cout<<std::endl<<"fmha_args: "<<std::endl;
std::cout<<"q_ptr: "<<fmha_args.q_ptr<<std::endl;
std::cout<<"k_ptr: "<<fmha_args.k_ptr<<std::endl;
std::cout<<"v_ptr: "<<fmha_args.v_ptr<<std::endl;
Expand All @@ -353,9 +385,15 @@ void log_bwd_config(const char* func_name,
std::cout<<"dk_ptr: "<<fmha_args.dk_ptr<<std::endl;
std::cout<<"dv_ptr: "<<fmha_args.dv_ptr<<std::endl;
std::cout<<"dbias_ptr: "<<fmha_args.dbias_ptr<<std::endl;
std::cout<<"dq_acc_ptr: "<<fmha_args.dq_acc_ptr<<std::endl;

std::cout<<"seqstart_q_ptr: "<<fmha_args.seqstart_q_ptr<<std::endl;
std::cout<<"seqstart_k_ptr: "<<fmha_args.seqstart_k_ptr<<std::endl;
std::cout<<"seqlen_q_ptr: "<<fmha_args.seqlen_q_ptr<<std::endl;
std::cout<<"seqlen_k_ptr: "<<fmha_args.seqlen_k_ptr<<std::endl;
std::cout<<"cu_seqlen_q_ptr: "<<fmha_args.cu_seqlen_q_ptr<<std::endl;
std::cout<<"cu_seqlen_k_ptr: "<<fmha_args.cu_seqlen_k_ptr<<std::endl;

std::cout<<"seqlen_q: "<<fmha_args.seqlen_q<<std::endl;
std::cout<<"seqlen_k: "<<fmha_args.seqlen_k<<std::endl;
std::cout<<"batch: "<<fmha_args.batch<<std::endl;
Expand Down Expand Up @@ -572,9 +610,12 @@ hipError_t ck_attn_bwd(
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr,
dq_acc_ptr, //dq_acc_buf
nullptr,//cu_seqlen_q
nullptr,//cu_seqlen_kv
nullptr,//seqstart_q_ptr
nullptr,//seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
nullptr, //cu_seqlen_q_ptr
nullptr, //cu_seqlen_k_ptr
shape_seqlen_q,
shape_seqlen_k,
batch,
Expand Down Expand Up @@ -780,6 +821,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand Down Expand Up @@ -911,11 +953,14 @@ hipError_t ck_attn_varlen_bwd(
dq_ptr,
is_mqa_gqa? dk_expanded_ptr:dk_ptr,
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
nullptr,
nullptr, //dbias_ptr
dq_acc_ptr, //dq_acc_buf
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr
cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
cu_seqlen_q_ptr, //cu_seqlen_q_ptr
cu_seqlen_kv_ptr, //cu_seqlen_k_ptr
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_k, //seqlen_kv, unused in group mode
batch,
Expand Down Expand Up @@ -973,21 +1018,33 @@ hipError_t ck_attn_varlen_bwd(
std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
}();

// modify the max_seqlen_q for better performance in 0-length cases
// lse_thd_ptr used as buffer
if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
if(std::string(env_p) == "1"){
if(ck_fused_attn_log_config){
std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
}
fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
}
}

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
if(average_runtime < 0){
//TODO: better error out system
throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
Expand All @@ -998,6 +1055,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_dv_reduce_thd: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dkv_expanded: "<<stride_h_dk_expanded<<std::endl;
Expand All @@ -1010,8 +1069,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_dv_reduce_thd<CK_TILE_TYPE>, grid, block, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
Expand All @@ -1022,6 +1082,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dk(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dk: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"stride_h_dk_expanded: "<<stride_h_dk_expanded<<std::endl;
std::cout<<"stride_s_dk_expanded: "<<stride_s_dk_expanded<<std::endl;
Expand All @@ -1032,8 +1094,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
static_cast<CK_TILE_TYPE*>(dk_ptr),
Expand All @@ -1042,6 +1105,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dv(d_v);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dv: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dv_expanded: "<<stride_h_dv_expanded<<std::endl;
std::cout<<"stride_s_dv_expanded: "<<stride_s_dv_expanded<<std::endl;
Expand All @@ -1052,8 +1117,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dv, 0, stream,
h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dv_expanded, stride_s_dv_expanded,
static_cast<CK_TILE_TYPE*>(dv_ptr),
Expand Down
Loading
Loading