[TRITON] fix sink_attn error when causal=true #1837
Conversation
Pull request overview
This PR fixes a bug in the Triton kernel implementation for sink attention backward pass when the causal flag is set to true. The issue involved incorrect block size calculations that were being divided by BLK_SLICE_FACTOR, which caused errors in the masked operations.
Changes:
- Corrected mask block size calculations in the backward causal kernel by removing the unnecessary division by BLK_SLICE_FACTOR
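For context, here is a minimal Python sketch (not the actual Triton kernel; the names merely mirror its constexpr parameters) of what the division controls: the masked rows at the causal boundary are processed in BLK_SLICE_FACTOR slices of BLOCK_M1 // BLOCK_SLICE_FACTOR rows each, so dropping the division is equivalent to forcing the factor to 1.

```python
# Illustrative sketch only -- a simplified model of the masked-block
# arithmetic at the causal boundary, not the real AITER kernel code.

def masked_block_steps(block_m1: int, blk_slice_factor: int) -> tuple[int, int]:
    """Return (mask_block_m1, num_masked_steps) for the causal boundary.

    The masked region of block_m1 rows is processed in blk_slice_factor
    slices of block_m1 // blk_slice_factor rows each. Removing the
    division (this PR's change) behaves like blk_slice_factor = 1.
    """
    mask_block_m1 = block_m1 // blk_slice_factor
    num_masked_steps = block_m1 // mask_block_m1  # == blk_slice_factor
    return mask_block_m1, num_masked_steps

print(masked_block_steps(128, 2))  # -> (64, 2): two masked slices of 64 rows
print(masked_block_steps(128, 1))  # -> (128, 1): one full-size masked step
```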
```diff
  descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

- MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
```
Corrected spelling of 'casual' to 'causal' in PR title and description. The PR metadata contains 'casual=true' which should be 'causal=true'.
Hello @kyle-256. Thanks for your PR! My review follows.
UT failures
Can you please provide more details about the UT failures you've been facing? Which Triton compiler version are you using? I wasn't able to reproduce any UT failure with the latest Triton and latest AITER. You can check the details below.
Triton commit: triton-lang/triton@20251a3
AITER commit: da29487
Test results on MI300:
```
root@f799ed2bcfbf:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
    MARKET_NAME: AMD Instinct MI300X
root@f799ed2bcfbf:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected

op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................ [100%]

=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
  /workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 192 passed, 9505 deselected, 4 warnings in 89.89s (0:01:29) ==========================
```
Test results on MI350:
```
root@ff4dcd1c1607:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
    MARKET_NAME: AMD Instinct MI355X
root@ff4dcd1c1607:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected

op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................ [100%]

=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
  /workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== 192 passed, 9505 deselected, 4 warnings in 44.28s ===============================
```
Changing BLK_SLICE_FACTOR may affect kernel performance
Can you please do some profiling on both MI300 and MI350, so we can be sure that changing BLK_SLICE_FACTOR to 1 doesn't introduce performance regressions? You can use the op_tests/op_benchmarks/triton/bench_mha.py benchmark script to collect performance data.
What are your target shapes?
```diff
  descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

- MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
```
I think the proper way of doing this change is to set BLK_SLICE_FACTOR to 1 instead of removing it from the kernel. Please take a look at the aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json and aiter/ops/triton/configs/gfx950-MHA-DEFAULT.json config files (bkwd_onekernel → onekernel → BLK_SLICE_FACTOR).
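To illustrate the suggested alternative, here is a hedged sketch of adjusting the tuned value through the config rather than deleting the parameter. The JSON excerpt below is a hypothetical minimal fragment following the key path the review names (bkwd_onekernel → onekernel → BLK_SLICE_FACTOR); the real gfx942/gfx950 config files contain many more keys and values.

```python
import json

# Hypothetical minimal excerpt of an MHA tuning config; the real files live at
# aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json and gfx950-MHA-DEFAULT.json.
config_text = """
{
  "bkwd_onekernel": {
    "onekernel": {
      "BLK_SLICE_FACTOR": 2
    }
  }
}
"""

cfg = json.loads(config_text)

# The reviewer's suggestion: change the tuned value to 1 rather than
# removing the parameter from the kernel code.
cfg["bkwd_onekernel"]["onekernel"]["BLK_SLICE_FACTOR"] = 1
print(cfg["bkwd_onekernel"]["onekernel"]["BLK_SLICE_FACTOR"])  # -> 1
```

This keeps BLK_SLICE_FACTOR available for future retuning on other shapes or architectures.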
BLK_SLICE_FACTOR is a performance tuning parameter; it's important to keep it.
@kyle-256, I'm double-checking performance. I'll post my results in the PR as soon as I get them.
Sharing my benchmarking numbers (all UTs passing):
MI300: [benchmark chart omitted]
MI350: [benchmark chart omitted]
I used the following quick-and-dirty script to get performance numbers:

```shell
#!/usr/bin/env bash
mode='bwd'
dtype='bf16'
sq=8192
sk="${sq}"
d=64
causal='yes'
common_args="--metric time -mode ${mode} --dtype ${dtype} -sq ${sq} -sk ${sk} -d ${d} -causal ${causal}"
echo "tp,b,layout,time_ms"
for tp in 1 8; do
    args="${common_args}"
    if [[ "${tp}" -eq 1 ]]; then
        hq=64
        hk=8
    elif [[ "${tp}" -eq 8 ]]; then
        hq=8
        hk=1
    fi
    args="${args} -hq ${hq} -hk ${hk}"
    for layout in 'bshd' 'thd'; do
        args="${args} --layout ${layout}"
        if [[ "${layout}" == 'bshd' ]]; then
            batch_sizes=(1)
        elif [[ "${layout}" == 'thd' ]]; then
            mapfile -t batch_sizes < <(seq 8 16)
        fi
        for b in "${batch_sizes[@]}"; do
            args="${args} -b ${b}"
            # shellcheck disable=SC2086
            time_ms=$(python op_tests/op_benchmarks/triton/bench_mha.py ${args} \
                2>&1 | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter ' ' --fields 7)
            echo "${tp},${b},${layout},${time_ms}"
        done
    done
done
```

Feel free to do your own experiments and test your target shapes.
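If it helps when comparing runs, a small Python snippet for summarizing the CSV the script above prints. The numbers in the sample string are made up for illustration; pipe the real script output into a file and read that instead.

```python
import csv
import io
from collections import defaultdict
from statistics import mean

# Made-up sample in the "tp,b,layout,time_ms" format the benchmark script emits.
sample = """tp,b,layout,time_ms
1,1,bshd,12.3
1,8,thd,10.1
8,1,bshd,4.2
8,8,thd,3.9
"""

# Group the timings by tensor layout and report the mean per group.
by_layout = defaultdict(list)
for row in csv.DictReader(io.StringIO(sample)):
    by_layout[row["layout"]].append(float(row["time_ms"]))

for layout, times in sorted(by_layout.items()):
    print(f"{layout}: mean {mean(times):.2f} ms over {len(times)} runs")
# -> bshd: mean 8.25 ms over 2 runs
# -> thd: mean 7.00 ms over 2 runs
```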
Motivation
Fix an error in the sink attention backward pass when causal=true.