Skip to content

Commit be9abf1

Browse files
authored
[Dev][Benchmark] Add MLA paged decoding example and benchmark script (tile-ai#158)
* [Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16 - Remove redundant `acc_s_0` fragment in flash attention kernel - Simplify memory copy and reduction operations - Reorder memory copy and scaling steps for improved performance - Add Hopper-specific synchronization method in CUDA reduce template - Update reduce operation to use architecture-specific synchronization * [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script - Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script - Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton - Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations - Implement performance comparison and CSV output for detailed performance analysis - Add command-line argument support for targeted benchmarking and comparison * [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision - Replace `d` parameter with `dv` to clarify value dimension in MLA decoding - Enhance block distribution logic for split KV processing - Improve handling of remaining blocks in split KV computation - Add initialization of `lse_max_local` to prevent potential precision issues - Optimize block start and range calculations for more accurate sequence processing * lint
1 parent 3c53297 commit be9abf1

File tree

6 files changed

+956
-30
lines changed

6 files changed

+956
-30
lines changed

0 commit comments

Comments
 (0)