Conversation
@yzh119 yzh119 commented Jan 18, 2024

In our initial design we used CUDA cooperative kernels and the grid synchronization feature for cross-threadblock reduction. Although this is slightly faster for multi-head attention without grouping (non-GQA), we found two issues with this implementation:

  1. The kernel scheduling for cross-threadblock merging (the code after `grid.sync()`) is sub-optimal: merging time becomes the bottleneck when the number of chunks to merge is large (e.g. in GQA).
  2. Not all hardware has a grid synchronization feature (only NVIDIA & AMD GPUs support it, AFAIK), which makes the FlashInfer implementation hard to generalize to other backends such as Metal.

In this PR we stop using the CUDA cooperative kernels feature and instead use a combination of two kernels (one for decode, another for merging the partial states). This hurts performance a little bit for some shapes, but I believe it is beneficial for longer-term development and maintenance.
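
For intuition, here is a minimal sketch of what such a merge kernel can look like: each chunk contributes a partial output and a log-sum-exp value, and the second kernel combines them with a numerically stable softmax-weighted sum. This is illustrative only, not FlashInfer's actual merge kernel; all names (`partial_o`, `partial_lse`, `head_dim`, ...) and the layout are assumptions.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Illustrative merge of per-chunk partial attention states (not FlashInfer's
// actual kernel). Each chunk i holds a partial output o_i (already normalized
// within its chunk) and a log-sum-exp lse_i; the merged output is the
// softmax-weighted combination of the partials.
__global__ void MergePartialStates(const float* __restrict__ partial_o,   // [num_chunks, head_dim]
                                   const float* __restrict__ partial_lse, // [num_chunks]
                                   float* __restrict__ merged_o,          // [head_dim]
                                   int num_chunks, int head_dim) {
  int d = blockIdx.x * blockDim.x + threadIdx.x;
  if (d >= head_dim) return;
  float m = -INFINITY, denom = 0.f, acc = 0.f;
  for (int i = 0; i < num_chunks; ++i) {
    float lse = partial_lse[i];
    float m_new = fmaxf(m, lse);
    float scale_old = expf(m - m_new);   // rescale previously accumulated terms
    float scale_new = expf(lse - m_new); // weight of the current chunk
    acc = acc * scale_old + partial_o[i * head_dim + d] * scale_new;
    denom = denom * scale_old + scale_new;
    m = m_new;
  }
  merged_o[d] = acc / denom;
}
```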

@yzh119 yzh119 merged commit 29733e7 into main Jan 18, 2024
@MasterJH5574 MasterJH5574 deleted the batch_decode_var_len_merge_states branch January 18, 2024 15:22
yzh119 added a commit that referenced this pull request Feb 17, 2025
#804 didn't implement split-k, which might result in performance
degradation when concurrency is not large enough. This PR fixes that issue.

We implemented the v2 scheduler and the write-through optimization described
in [our paper](https://arxiv.org/pdf/2501.01005) (Section 3.3 and
Appendix D.2) for load balancing.
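
As a rough illustration of the split-k idea (not the v2 scheduler itself), long KV sequences can be split into chunks so that the number of work tiles is large enough to keep all SMs busy, with the per-chunk partial results merged afterwards. The names, `WorkTile` struct, and chunk-size heuristic below are all assumptions:

```cuda
#include <algorithm>
#include <vector>

// Hypothetical work tile: one (request, KV range) pair processed by one CTA.
struct WorkTile {
  int request_idx;
  int kv_start;
  int kv_end;
};

// Illustrative split-k partitioning: pick a chunk size so that the total
// number of tiles is at least the number of SMs, then slice every request's
// KV range into chunks of that size.
std::vector<WorkTile> SplitKV(const std::vector<int>& kv_lens, int num_sms) {
  long long total_kv = 0;
  for (int len : kv_lens) total_kv += len;
  long long target_tiles = std::max(num_sms, 1);
  long long chunk =
      std::max<long long>(256, (total_kv + target_tiles - 1) / target_tiles);
  std::vector<WorkTile> tiles;
  for (int r = 0; r < static_cast<int>(kv_lens.size()); ++r) {
    for (long long s = 0; s < kv_lens[r]; s += chunk) {
      tiles.push_back({r, static_cast<int>(s),
                       static_cast<int>(std::min<long long>(s + chunk, kv_lens[r]))});
    }
  }
  return tiles;
}
```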

In an earlier PR (#72), we
turned off `cudaLaunchCooperativeKernel` and `grid.sync()` because we
were not sure whether they are compatible with CUDAGraph. This PR adds them
back for grid synchronization, to save some kernel launch
overhead.
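
For reference, a minimal sketch of the cooperative-launch pattern being re-enabled (illustrative only; the kernel body, names, and launch parameters are assumptions, not FlashInfer's code):

```cuda
#include <cuda_runtime.h>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// One kernel does the per-chunk work, synchronizes the whole grid, and then
// reduces the partial results, avoiding a second kernel launch.
__global__ void DecodeThenMerge(float* partial, float* out, int n) {
  cg::grid_group grid = cg::this_grid();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) partial[i] = static_cast<float>(i);  // stand-in for per-chunk decode work
  grid.sync();  // cross-threadblock barrier; requires a cooperative launch
  if (i == 0) {
    float acc = 0.f;
    for (int j = 0; j < n; ++j) acc += partial[j];  // stand-in for the merge
    out[0] = acc;
  }
}

// Cooperative kernels must be launched with cudaLaunchCooperativeKernel
// (not <<<...>>>), and the grid must fit co-resident on the device.
cudaError_t LaunchDecodeThenMerge(float* partial, float* out, int n,
                                  int num_blocks, int threads_per_block,
                                  cudaStream_t stream) {
  void* args[] = {&partial, &out, &n};
  return cudaLaunchCooperativeKernel((void*)DecodeThenMerge, dim3(num_blocks),
                                     dim3(threads_per_block), args,
                                     /*sharedMemBytes=*/0, stream);
}
```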

## Benchmark

On H100 SXM5 80GB (3352 GB/s peak memory bandwidth), after this PR:
```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 22.33 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 330.72 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 638.73 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 1188.90 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 40.74 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 592.77 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 1112.83 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 1506.01 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 72.53 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 1007.80 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 1438.99 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1730.62 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 120.74 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 1340.86 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 1689.36 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1901.26 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 177.94 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 1619.51 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 1876.50 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 2010.58 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 231.70 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 1835.16 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 1997.24 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 2067.99 GB/s
```

Before this PR:
```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 15.46 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 238.49 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 472.44 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 929.12 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 15.47 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 250.71 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 500.21 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 996.37 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 16.36 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 257.59 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 515.88 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1035.55 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 16.37 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 261.47 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 524.76 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1054.54 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 16.50 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 263.69 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 528.89 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 1064.87 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 16.45 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 264.66 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 530.87 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 1070.93 GB/s
```
diptorupd pushed a commit to ROCm/flashinfer that referenced this pull request Sep 29, 2025
This PR fixes a bug in decode.cuh. 

In the `test_single_decode.cpp` file, we were only comparing results when
`DTypeQO` was `__half`, due to an incorrect `if` condition inside the
tester.

This PR also makes changes to `decode.cuh`. For the CDNA3 architecture,
we have disabled the use of shared memory for now; all `seq_len` values now
use the naive implementation.

Tested using the C++ tests and `examples/test_batch_decode_example.py`.