-
Notifications
You must be signed in to change notification settings - Fork 558
[Refactor] Use two kernels instead of CUDA cooperative kernel for batch/single decode #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
yzh119
added a commit
that referenced
this pull request
Feb 17, 2025
#804 didn't implement split-k, which might result in performance degradation if concurrency is not large enough. This PR fixes issue. We implemented the v2 scheduler and write-through optimization mentioned in [our paper](https://arxiv.org/pdf/2501.01005) (section 3.3 and appendix in D.2) for load-balancing. In an early PR (#72), we turned off `cudaLaunchCooperativeKernels` and `grid.sync()` because we are not sure whether it's compatible with CUDAGraph. This PR adds them back again for grid synchronization, to save some kernel launch overhead. ## Benchmark On H100 SXM5 80GB (3352 GB/s), this PR: ``` Config: batch_size=1, seq_len=1024, num_heads=16 Memory bandwidth: 22.33 GB/s Config: batch_size=16, seq_len=1024, num_heads=16 Memory bandwidth: 330.72 GB/s Config: batch_size=32, seq_len=1024, num_heads=16 Memory bandwidth: 638.73 GB/s Config: batch_size=64, seq_len=1024, num_heads=16 Memory bandwidth: 1188.90 GB/s Config: batch_size=1, seq_len=2048, num_heads=16 Memory bandwidth: 40.74 GB/s Config: batch_size=16, seq_len=2048, num_heads=16 Memory bandwidth: 592.77 GB/s Config: batch_size=32, seq_len=2048, num_heads=16 Memory bandwidth: 1112.83 GB/s Config: batch_size=64, seq_len=2048, num_heads=16 Memory bandwidth: 1506.01 GB/s Config: batch_size=1, seq_len=4096, num_heads=16 Memory bandwidth: 72.53 GB/s Config: batch_size=16, seq_len=4096, num_heads=16 Memory bandwidth: 1007.80 GB/s Config: batch_size=32, seq_len=4096, num_heads=16 Memory bandwidth: 1438.99 GB/s Config: batch_size=64, seq_len=4096, num_heads=16 Memory bandwidth: 1730.62 GB/s Config: batch_size=1, seq_len=8192, num_heads=16 Memory bandwidth: 120.74 GB/s Config: batch_size=16, seq_len=8192, num_heads=16 Memory bandwidth: 1340.86 GB/s Config: batch_size=32, seq_len=8192, num_heads=16 Memory bandwidth: 1689.36 GB/s Config: batch_size=64, seq_len=8192, num_heads=16 Memory bandwidth: 1901.26 GB/s Config: batch_size=1, seq_len=16384, num_heads=16 Memory bandwidth: 177.94 GB/s Config: batch_size=16, seq_len=16384, num_heads=16 Memory bandwidth: 1619.51 GB/s Config: batch_size=32, seq_len=16384, num_heads=16 Memory bandwidth: 1876.50 GB/s Config: batch_size=64, seq_len=16384, num_heads=16 Memory bandwidth: 2010.58 GB/s Config: batch_size=1, seq_len=32768, num_heads=16 Memory bandwidth: 231.70 GB/s Config: batch_size=16, seq_len=32768, num_heads=16 Memory bandwidth: 1835.16 GB/s Config: batch_size=32, seq_len=32768, num_heads=16 Memory bandwidth: 1997.24 GB/s Config: batch_size=64, seq_len=32768, num_heads=16 Memory bandwidth: 2067.99 GB/s ``` Before this PR: ``` Config: batch_size=1, seq_len=1024, num_heads=16 Memory bandwidth: 15.46 GB/s Config: batch_size=16, seq_len=1024, num_heads=16 Memory bandwidth: 238.49 GB/s Config: batch_size=32, seq_len=1024, num_heads=16 Memory bandwidth: 472.44 GB/s Config: batch_size=64, seq_len=1024, num_heads=16 Memory bandwidth: 929.12 GB/s Config: batch_size=1, seq_len=2048, num_heads=16 Memory bandwidth: 15.47 GB/s Config: batch_size=16, seq_len=2048, num_heads=16 Memory bandwidth: 250.71 GB/s Config: batch_size=32, seq_len=2048, num_heads=16 Memory bandwidth: 500.21 GB/s Config: batch_size=64, seq_len=2048, num_heads=16 Memory bandwidth: 996.37 GB/s Config: batch_size=1, seq_len=4096, num_heads=16 Memory bandwidth: 16.36 GB/s Config: batch_size=16, seq_len=4096, num_heads=16 Memory bandwidth: 257.59 GB/s Config: batch_size=32, seq_len=4096, num_heads=16 Memory bandwidth: 515.88 GB/s Config: batch_size=64, seq_len=4096, num_heads=16 Memory bandwidth: 1035.55 GB/s Config: batch_size=1, seq_len=8192, num_heads=16 Memory bandwidth: 16.37 GB/s Config: batch_size=16, seq_len=8192, num_heads=16 Memory bandwidth: 261.47 GB/s Config: batch_size=32, seq_len=8192, num_heads=16 Memory bandwidth: 524.76 GB/s Config: batch_size=64, seq_len=8192, num_heads=16 Memory bandwidth: 1054.54 GB/s Config: batch_size=1, seq_len=16384, num_heads=16 Memory bandwidth: 16.50 GB/s Config: batch_size=16, seq_len=16384, num_heads=16 Memory bandwidth: 263.69 GB/s Config: batch_size=32, seq_len=16384, num_heads=16 Memory bandwidth: 528.89 GB/s Config: batch_size=64, seq_len=16384, num_heads=16 Memory bandwidth: 1064.87 GB/s Config: batch_size=1, seq_len=32768, num_heads=16 Memory bandwidth: 16.45 GB/s Config: batch_size=16, seq_len=32768, num_heads=16 Memory bandwidth: 264.66 GB/s Config: batch_size=32, seq_len=32768, num_heads=16 Memory bandwidth: 530.87 GB/s Config: batch_size=64, seq_len=32768, num_heads=16 Memory bandwidth: 1070.93 GB/s ```
diptorupd
pushed a commit
to ROCm/flashinfer
that referenced
this pull request
Sep 29, 2025
This PR fixes a bug in decode.cuh. In the `test_single_decode.cpp` file, we were only comparing results if `DTypeQO` was `__half`. This was due to an incorrect if condition inside the tester. This PR also makes changes to `decode.cuh`. For ther CDNA3 architecture, we have disabled using shared_memory for now. All `seq_len` now use the naive implementation. Tested using C++ and `examples/test_batch_decode_example.py`
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
In our initial design we use CUDA cooperative kernels and grid synchronization feature for cross threadblock reduction. Although it's slightly faster for multi-head attention without grouping (GQA), we found there are two issues with this implementation:
grid.sync) is sub-optimal, the merging time would be the bottleneck when the number of chunks to merge is huge (e.g. in GQA).In this PR we stop using CUDA cooperative kernels features and using the combo two kernels (one for decode, another for merge) instead, it harms performance a little bit for some shapes but I believe it's beneficial for longer-term development and maintainance.