Closed
@@ -848,7 +848,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
else:
descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

- MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
Copilot AI Jan 14, 2026

Corrected spelling of 'casual' to 'causal' in PR title and description. The PR metadata contains 'casual=true' which should be 'causal=true'.

Contributor

I think the proper way to make this change is to set BLK_SLICE_FACTOR to 1 rather than removing it from the kernel. Please take a look at the config files aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json and aiter/ops/triton/configs/gfx950-MHA-DEFAULT.json (bkwd_onekernel -> onekernel -> BLK_SLICE_FACTOR).

Contributor

BLK_SLICE_FACTOR is a performance tuning parameter, so it's important to keep it.
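
To illustrate the point raised in this thread, here is a minimal plain-Python sketch of the block-size arithmetic touched by this diff. It is not the kernel code: `cdiv` and `masked_steps` are hypothetical helpers written for this example, and the tile sizes are made-up values, not the ones in the config files. It only shows that a slice factor of 1 gives the same masked block size as dropping the division, while other values still change the step count.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (the role tl.cdiv plays in the kernel)."""
    return (a + b - 1) // b


def masked_steps(len_m: int, seqlen_q: int, block_m1: int, blk_slice_factor: int) -> int:
    """Hypothetical helper: number of masked iterations, following the arithmetic in this diff."""
    # Masked block size derived from the tile size and the slice factor.
    mask_block_m1 = block_m1 // blk_slice_factor
    # Bound the masked range to the query length, as the kernel comment describes.
    len_m = min(len_m, seqlen_q)
    return cdiv(len_m, mask_block_m1)


# Made-up sizes, for illustration only.
steps_sliced = masked_steps(len_m=128, seqlen_q=1024, block_m1=32, blk_slice_factor=2)
steps_unsliced = masked_steps(len_m=128, seqlen_q=1024, block_m1=32, blk_slice_factor=1)

print(steps_sliced, steps_unsliced)
```

With a factor of 1 the masked block equals BLOCK_M1, which matches the behavior of the proposed change while leaving the parameter available for tuning through the config.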

# bound the masked operation to q len so it does not have to waste cycles
len_m = min(len_m, seqlen_q)
num_steps = tl.cdiv(len_m, MASK_BLOCK_M1)
@@ -1069,7 +1069,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
dsink = tl.sum(-psink * delta[:, None])
tl.atomic_add(DSink + hqid, dsink, sem="relaxed")

- MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2
# start can only be 0 at minimum
start_n = max(end_n - BLOCK_M2, 0)
num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2)