Conversation
Pull request overview
This pull request fixes sliding window attention with Multi-Token Processing (MTP) in the paged attention decode implementation, adding support for KV_BLOCK_SIZE=1024 and improving the sliding window causal masking logic.
Changes:
- Added support for KV_BLOCK_SIZE=1024 in sliding window kernels with appropriate page offset calculations and windowing masks
- Fixed causal masking for sliding window to correctly handle per-query-position windows (see the mask sketch after this list)
- Reorganized kernel code for better performance by moving initialization earlier and consolidating the PS path
- Reduced MAX_CONTEXT_PARTITION_NUM from 16 to 8 to avoid exceeding shared memory limits
- Expanded test coverage for sliding window scenarios
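For reference, here is a minimal sketch of the per-query-position masking rule in plain Python. The function and argument names are illustrative, not the kernel's actual implementation: with MTP, the last `num_queries` tokens of a sequence are decoded together, and each gets a causal bound and a sliding window anchored at its own absolute position.

```python
# Sketch only (assumed names/shapes): per-query-position sliding window
# causal mask for MTP decode.
def sliding_window_causal_mask(context_len, num_queries, num_keys, sliding_window):
    mask = [[False] * num_keys for _ in range(num_queries)]
    for q in range(num_queries):
        q_pos = context_len - num_queries + q   # absolute position of query q
        for k in range(num_keys):
            causal = k <= q_pos                              # no future keys
            in_window = sliding_window <= 0 or q_pos - k < sliding_window
            mask[q][k] = causal and in_window
    return mask
```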
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| op_tests/triton_tests/test_pa_decode_gluon.py | Tightened diff tolerance from 8e-2 to 5e-2 and expanded test coverage with additional head dimensions, quantization modes, and configurations |
| aiter/ops/triton/gluon/pa_decode_gluon.py | Added KV_BLOCK_SIZE=1024 support with page offset handling, fixed sliding window causal masking, reorganized initialization code, reduced MAX_CONTEXT_PARTITION_NUM to 8, and moved PS kernel path to top of wrapper |
…kernel Co-authored-by: Cursor <cursoragent@cursor.com>
…o cache key in sliding window decode Co-authored-by: Cursor <cursoragent@cursor.com>
Motivation

`KV_BLOCK_SIZE=1024` is supported by the cache layout, but the PS (partitioned-softmax) decode path previously assumed smaller KV block sizes and could:
- mishandle `block_size=1024` when `sliding_window>0`
- fail for large `context_partition_num` due to Triton tensor size limits

This PR makes PS decode robust for `KV_BLOCK_SIZE=1024` and fixes PS reduction compilation/resource issues.

Technical Details
1) `paged_attention_decode_sliding_window`: add `KV_BLOCK_SIZE=1024` support
- Accept `KV_BLOCK_SIZE` in `[16, 64, 1024]`.
- For `KV_BLOCK_SIZE==1024`, treat the KV page as 4 tiles of 256 tokens: `KV_COMPUTE_BLOCK_SIZE = CONTEXT_PARTITION_SIZE (256)`.
- Compute `page_offset ∈ {0, 256, 512, 768}` and apply it to `stride_key_block_elem` when stepping through KV elements to match the actual key cache layout.
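As a rough illustration of the tiling above (helper and variable names here are assumptions, not the kernel's exact code), a 1024-token page is addressed as four 256-token tiles, with the tile's `page_offset` folded into the key-cache element stride:

```python
# Sketch only: map a 256-token context partition onto a 1024-token KV page.
CONTEXT_PARTITION_SIZE = 256
KV_BLOCK_SIZE = 1024

def locate_partition(partition_idx, block_table):
    token_start = partition_idx * CONTEXT_PARTITION_SIZE
    page_idx = token_start // KV_BLOCK_SIZE        # which 1024-token page
    page_offset = token_start % KV_BLOCK_SIZE      # one of {0, 256, 512, 768}
    physical_block = block_table[page_idx]
    # The kernel applies the offset via the element stride, conceptually:
    #   key_ptr = key_cache + physical_block * stride_key_block
    #                       + page_offset * stride_key_block_elem
    return physical_block, page_offset
```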
2) PS wrapper fixes
- Pass `ONE_SHOT=(num_splits <= 1)` into `paged_attention_decode_sliding_window`.
- For `KV_BLOCK_SIZE==1024`, use `waves_per_eu=1`; otherwise `waves_per_eu=4`.
- Use `num_stages=1`.
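A hedged sketch of what that wrapper-side selection amounts to (the helper name and return shape are placeholders; the real wrapper passes these as kernel constexprs and launch metadata):

```python
def select_ps_launch_params(num_splits, kv_block_size):
    """Sketch: choose PS decode launch parameters (illustrative names)."""
    one_shot = num_splits <= 1                        # ONE_SHOT=(num_splits <= 1)
    waves_per_eu = 1 if kv_block_size == 1024 else 4  # fewer waves for big pages
    return dict(ONE_SHOT=one_shot, waves_per_eu=waves_per_eu, num_stages=1)
```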
3) PS reduce kernel: avoid Triton `numel` limit and shared memory overflow
- `paged_attention_decode_ps_reduce_kernel` now reduces partitions in chunks (two-pass reduction), instead of materializing tensors sized by `next_power_of_2(context_partition_num)`.
- Each chunk covers `<= 8` partitions; this avoids `ValueError('numel (...) exceeds triton maximum tensor numel (1048576)')`, previously hit on large configurations (`qg=64, head=128`).
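To make the chunking concrete, here is a NumPy sketch of a two-pass partitioned-softmax reduction under assumed inputs (per-partition `max_logits`, `exp_sums`, and normalized partial outputs); the real kernel operates on Triton tensors, but the structure is the same: only chunk-sized tensors are ever live, so nothing scales with `next_power_of_2(context_partition_num)`.

```python
import numpy as np

CHUNK = 8  # reduce at most 8 partitions per step

def reduce_partitions(max_logits, exp_sums, partials):
    # max_logits: (P,), exp_sums: (P,), partials: (P, head_dim)
    P = max_logits.shape[0]
    global_max = -np.inf
    for s in range(0, P, CHUNK):                 # pass 1: global running max
        global_max = max(global_max, max_logits[s:s + CHUNK].max())
    acc = np.zeros(partials.shape[1])
    denom = 0.0
    for s in range(0, P, CHUNK):                 # pass 2: rescale + accumulate
        scale = np.exp(max_logits[s:s + CHUNK] - global_max) * exp_sums[s:s + CHUNK]
        acc += scale @ partials[s:s + CHUNK]
        denom += scale.sum()
    return acc / denom
```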
Test Plan
- `op_tests/triton_tests/test_pa_decode_gluon.py`: `block_size=1024`, `context_partition_size=256`, `kv_varlen=True`, `trans_v=False`
- `sliding_window=0` and `sliding_window=128`
- `batch_size=1` and `batch_size=128`
- `block_size=16` using the same harness.

Test Result
Submission Checklist