Commit be9abf1
[Dev][Benchmark] Add MLA paged decoding example and benchmark script (tile-ai#158)
* [Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16
- Remove redundant `acc_s_0` fragment in flash attention kernel
- Simplify memory copy and reduction operations
- Reorder memory copy and scaling steps for improved performance
- Add Hopper-specific synchronization method in CUDA reduce template
- Update reduce operation to use architecture-specific synchronization
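The precision issue behind the first bullet group can be illustrated with a small, self-contained sketch (this is not the kernel code; the values and the `float16` softmax path are hypothetical). Casting attention scores (`acc_s`) to float16 before the max-subtraction step can overflow or lose significance, since float16 tops out near 65504; keeping the subtraction and exponentiation in float32 and casting only at the end avoids this:

```python
import numpy as np

# Hypothetical row of attention scores (acc_s), accumulated in float32.
scores = np.float32([12.0, 11.5, 90.0])

# Naive order: cast to float16 first, then exponentiate.
# exp(90) (and even exp(12)) exceeds float16's max of ~65504 -> inf.
s16 = scores.astype(np.float16)
naive = np.exp(s16)

# Reordered: subtract the row max and exponentiate in float32,
# then cast the bounded result (values in (0, 1]) to float16.
stable = np.exp(scores - scores.max()).astype(np.float16)
```

The commit's change is the same idea applied inside the flash-attention kernel: do the scaling and reduction work in float32 and defer the float16 cast until the values are bounded.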
* [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script
- Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script
- Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton
- Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations
- Implement performance comparison and CSV output for detailed performance analysis
- Add command-line argument support for targeted benchmarking and comparison
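The benchmark-script bullets above can be sketched as a minimal harness; the backend names, stub implementations, and flag names here are placeholders, not the actual `examples/deepseek_mla` script:

```python
import argparse
import csv
import time

def bench(fn, warmup=3, iters=10):
    """Return mean wall-clock time per call in milliseconds."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64)
    parser.add_argument("--seqlen", type=int, default=4096)
    parser.add_argument("--heads", type=int, default=128)
    parser.add_argument("--backends", nargs="+",
                        default=["torch", "tilelang"])
    parser.add_argument("--out", default="mla_bench.csv")
    args = parser.parse_args()

    # Stub kernels; the real script dispatches to Torch, TileLang,
    # FlashMLA, FlashInfer, and Triton implementations.
    impls = {"torch": lambda: None, "tilelang": lambda: None}

    with open(args.out, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["backend", "batch", "seqlen", "heads", "ms"])
        for name in args.backends:
            ms = bench(impls[name])
            writer.writerow([name, args.batch, args.seqlen,
                             args.heads, ms])

if __name__ == "__main__":
    main()
```

Writing one CSV row per (backend, shape) pair makes cross-implementation comparison a simple post-processing step over the output file.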
* [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision
- Replace `d` parameter with `dv` to clarify value dimension in MLA decoding
- Enhance block distribution logic for split KV processing
- Improve handling of remaining blocks in split KV computation
- Add initialization of `lse_max_local` to prevent potential precision issues
- Optimize block start and range calculations for more accurate sequence processing
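The block-distribution bullets describe splitting a sequence's KV blocks across several partial-attention workers, including the "remaining blocks" case where the count does not divide evenly. A hedged host-side sketch of that range computation (the kernel does this arithmetic per split; the function name here is hypothetical):

```python
def split_kv_ranges(num_blocks: int, num_splits: int):
    """Partition [0, num_blocks) into num_splits contiguous ranges.

    The remainder is spread over the first few splits, so no split
    processes more than one block beyond the base share.
    """
    base, rem = divmod(num_blocks, num_splits)
    ranges, start = [], 0
    for i in range(num_splits):
        size = base + (1 if i < rem else 0)
        ranges.append((start, start + size))
        start += size
    return ranges
```

For example, 10 blocks over 3 splits yields ranges of sizes 4, 3, and 3; an even balance like this keeps the per-split log-sum-exp combine step (where `lse_max_local` must be initialized) from being dominated by one oversized split.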
* lint
File tree (6 files changed: +956, -30 lines)
- examples/deepseek_mla
- testing/python/autotune
- tilelang/autotuner
- tilelang/engine