Commit fbd28e5
[Example] Implement FMHA Varlen Example (#131)
* Add DeepSeek MLA decode example with Flash Attention implementation
* Add GEMM SplitK and StreamK example implementations
This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang
Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
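The Split-K strategy itself is easy to picture outside of TileLang: the reduction dimension K is partitioned, each partition produces a partial product, and the partials are accumulated (on the GPU, typically via atomic adds). A minimal PyTorch sketch of the idea with an assumed split factor; this is not the TileLang kernel from the example:

```python
import torch

def splitk_matmul_reference(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    """Illustrative Split-K GEMM: partition the reduction (K) dimension into
    `split_k` chunks, compute a partial product per chunk, then sum the partials.
    In a real kernel each chunk is handled by a separate block and combined with
    atomic adds; here the accumulation is a plain sum."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and K % split_k == 0, "K must be divisible by split_k in this sketch"
    chunk = K // split_k
    partials = [
        A[:, i * chunk:(i + 1) * chunk] @ B[i * chunk:(i + 1) * chunk, :]
        for i in range(split_k)
    ]
    return torch.stack(partials).sum(dim=0)

# Matches a plain matmul up to floating-point accumulation order.
A = torch.randn(128, 256)
B = torch.randn(256, 64)
torch.testing.assert_close(splitk_matmul_reference(A, B), A @ B, rtol=1e-4, atol=1e-4)
```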
* Refactor GEMM SplitK and StreamK example implementations
Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity
* Add block sparse attention benchmarks for multiple libraries
This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation
The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
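As a rough illustration of the top-k and threshold mask generation mentioned above, here is a minimal PyTorch sketch; the tensor layout and function names are assumptions, not the benchmarks' exact code:

```python
import torch

def topk_block_mask(block_scores: torch.Tensor, topk: int) -> torch.Tensor:
    """Illustrative top-k sparse mask: given per-(query-block, key-block) scores of
    shape [batch, heads, num_q_blocks, num_k_blocks], keep the top-k key blocks
    for every query block and mask out the rest."""
    idx = block_scores.topk(topk, dim=-1).indices
    mask = torch.zeros_like(block_scores)
    mask.scatter_(-1, idx, 1.0)
    return mask.bool()

def threshold_block_mask(block_scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """Illustrative threshold mask: keep key blocks whose score exceeds `threshold`."""
    return block_scores > threshold

scores = torch.randn(1, 8, 16, 16)   # batch=1, heads=8, 16x16 blocks
mask = topk_block_mask(scores, topk=4)
print(mask.sum(dim=-1))              # every query block keeps exactly 4 key blocks
```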
* Refactor block sparse attention benchmarks with code style improvements
- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks
* lint fix
* Add CUDA atomic operations for BFLOAT16 and update function naming
- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module
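The packing helper's job can be pictured in plain Python: two bfloat16 values share a single 32-bit word so that a packed atomic can update both at once. A minimal sketch of that bit layout (which value lands in the low half is an assumption here, and this is not the CUDA implementation):

```python
import torch

def pack_bfloat16x2(lo: float, hi: float) -> int:
    """Illustrative packing of two bfloat16 values into one 32-bit word, mirroring
    what a bfloat16x2 packing helper does on the CUDA side. The ordering (first
    value in the low 16 bits) is an assumption for illustration."""
    bits = torch.tensor([lo, hi], dtype=torch.bfloat16).view(torch.int16)
    lo_bits = int(bits[0]) & 0xFFFF
    hi_bits = int(bits[1]) & 0xFFFF
    return lo_bits | (hi_bits << 16)

print(hex(pack_bfloat16x2(1.0, -2.5)))  # one 32-bit word holding both values
```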
* lint fix
* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang
This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison
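A minimal PyTorch reference for the GQA pattern in the BSHD layout ([batch, seq, heads, dim]); the function name and the use of scaled_dot_product_attention are illustrative, not the example's exact code:

```python
import torch
import torch.nn.functional as F

def gqa_reference(q, k, v):
    """Illustrative GQA forward: q has more heads than k/v, so each group of
    query heads shares one key/value head. Inputs are BSHD:
    q: [batch, seq, q_heads, dim], k/v: [batch, seq, kv_heads, dim]."""
    q_heads, kv_heads = q.shape[2], k.shape[2]
    groups = q_heads // kv_heads
    # Repeat the KV heads so every query head has a matching key/value head.
    k = k.repeat_interleave(groups, dim=2)
    v = v.repeat_interleave(groups, dim=2)
    # scaled_dot_product_attention expects [batch, heads, seq, dim].
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)  # back to BSHD

q = torch.randn(2, 128, 32, 64)   # 32 query heads
k = torch.randn(2, 128, 8, 64)    # 8 KV heads -> group size 4
v = torch.randn(2, 128, 8, 64)
print(gqa_reference(q, k, v).shape)  # torch.Size([2, 128, 32, 64])
```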
* Refactor IR lowering pipeline into modular phases
This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations
The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
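Conceptually, the split lets the lowering entry points compose two explicit phases instead of one monolithic pass list. A hypothetical sketch of that shape; the real functions in `phase.py` operate on TVM IRModules and their signatures may differ:

```python
# Hypothetical sketch of how the two phases compose; the real implementations
# live in tilelang's phase.py and run pass pipelines over an IRModule.
def LowerAndLegalize(mod, target):
    # Phase 1: legalize and normalize the frontend IR (stubbed here).
    return mod

def OptimizeForTarget(mod, target):
    # Phase 2: apply target-specific optimizations (stubbed here).
    return mod

def lower(mod, target):
    mod = LowerAndLegalize(mod, target)
    return OptimizeForTarget(mod, target)
```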
* lint fix
* NSA kernel
* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates
- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations
* Add utility math functions for integer operations
- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers
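These helpers are small enough to sketch directly; the versions below are illustrative and may differ in edge-case handling from the utility module:

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two greater than or equal to n (illustrative version)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    """Ceiling division: number of size-b tiles needed to cover a elements."""
    return (a + b - 1) // b

assert next_power_of_2(1) == 1 and next_power_of_2(5) == 8 and next_power_of_2(8) == 8
assert cdiv(10, 4) == 3 and cdiv(8, 4) == 2
```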
* Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation
- Update flash attention kernel to support positional embeddings (PE)
- Modify reference implementation to handle PE and group query attention
- Increase default batch size and adjust benchmarking parameters
- Improve kernel performance and readability
- Add einops and torch operations for more flexible tensor manipulation
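The PE handling in the reference can be pictured as attention whose scores add a content ("nope") term and a positional (PE) term. A simplified single-step decode sketch in PyTorch; shapes and names are assumptions, not the example's exact code:

```python
import math
import torch

def mla_decode_reference(q_nope, q_pe, k_nope, k_pe, v):
    """Simplified decode-step attention with a split positional-embedding part.
    q_nope: [batch, heads, dim], q_pe: [batch, heads, pe_dim]
    k_nope: [batch, seqlen, dim], k_pe: [batch, seqlen, pe_dim]
    v:      [batch, seqlen, dim_v]"""
    dim, pe_dim = q_nope.shape[-1], q_pe.shape[-1]
    scale = 1.0 / math.sqrt(dim + pe_dim)
    # Scores are the sum of the content term and the positional (PE) term.
    scores = torch.einsum("bhd,bnd->bhn", q_nope, k_nope)
    scores = scores + torch.einsum("bhr,bnr->bhn", q_pe, k_pe)
    attn = torch.softmax(scores * scale, dim=-1)
    return torch.einsum("bhn,bnd->bhd", attn, v)

q_nope = torch.randn(1, 16, 512)
q_pe = torch.randn(1, 16, 64)
k_nope = torch.randn(1, 1024, 512)
k_pe = torch.randn(1, 1024, 64)
v = torch.randn(1, 1024, 512)
print(mla_decode_reference(q_nope, q_pe, k_nope, k_pe, v).shape)  # [1, 16, 512]
```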
* Update README.md with corrected Flash MLA Decoding example path
- Modify the example link for Flash MLA Decoding to point to the correct directory
- Ensure accurate navigation to the DeepSeek MLA decoding example
* Refactor Native Sparse Attention Kernel and Improve Utility Functions
This commit introduces several improvements:
- Simplified native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py
- Enhanced error handling in loop_partition.cc with more informative error messages
- Updated print.py to support multi-dimensional buffer printing
- Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting
- Relaxed default absolute tolerance in torch comparison from 1e-3 to 1e-2
- Added shape validation and detailed mismatch information in tensor comparison
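The kind of detailed mismatch reporting described above can be sketched as a thin wrapper around an element-wise tolerance check; this is an illustration of the idea, not `torch_assert_close` itself, and the `max_mismatched_ratio` parameter is assumed:

```python
import torch

def assert_close_with_report(actual, expected, rtol=1e-2, atol=1e-2, max_mismatched_ratio=0.0):
    """Illustrative detailed comparison: validates shapes, counts elements that
    exceed the rtol/atol bound, and reports how many (and what fraction) differ."""
    assert actual.shape == expected.shape, (
        f"shape mismatch: {tuple(actual.shape)} vs {tuple(expected.shape)}")
    diff = (actual.float() - expected.float()).abs()
    tol = atol + rtol * expected.float().abs()
    mismatched = diff > tol
    count = int(mismatched.sum())
    ratio = count / mismatched.numel()
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{count}/{mismatched.numel()} elements ({ratio:.2%}) exceed tolerance "
            f"(atol={atol}, rtol={rtol}); max abs diff={diff.max().item():.3e}")

a = torch.randn(64, 64)
assert_close_with_report(a, a + 1e-4 * torch.randn_like(a))  # passes with loose tolerances
```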
* Refactor Code Formatting and Improve Utility Functions
This commit introduces several code formatting and utility improvements:
- Add Ruff linter ignore comment in example_tilelang_nsa.py
- Enhance code readability in loop_partition.cc and lower_tile_op.cc with improved line breaks
- Simplify print_flat_buffer_with_condition in print.py
- Refactor torch_assert_close in testing/__init__.py with improved line formatting
* Enhance Buffer Printing Support for Fragment and Shared Memory Buffers
This commit improves the print functionality in print.py by:
- Adding support for printing fragment memory buffers
- Implementing a new print_fragment_buffer_with_condition macro
- Extending print_shared_buffer_with_condition for shared memory buffers
- Updating the generic print function to handle different buffer scopes
* Resolve merge conflict in print.py
Remove merge conflict marker and clean up whitespace in the print module
* Add Variable-Length Multi-Head Attention (MHA) Example with Flash Attention Support
Introduce a new example script `example_mha_fwd_varlen.py` that demonstrates:
- Variable-length Multi-Head Attention (MHA) implementation
- Flash Attention forward pass with padding mask support
- Performance benchmarking for variable-length sequences
- Configurable parameters for batch, heads, sequence length, and dimension
- Reference implementation comparison with PyTorch and FlashAttention
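Variable-length attention typically packs all sequences into a single [total_tokens, heads, dim] tensor and delimits them with cumulative sequence lengths (cu_seqlens). A PyTorch reference in that style; layout and names are assumptions, not the example's exact code:

```python
import torch
import torch.nn.functional as F

def varlen_mha_reference(q, k, v, cu_seqlens):
    """Illustrative varlen MHA: q/k/v are packed as [total_tokens, heads, dim]
    with per-sequence boundaries given by cu_seqlens (length batch+1).
    Each sequence is attended to independently, so padding is never touched."""
    out = torch.empty_like(q)
    bounds = cu_seqlens.tolist()
    for start, end in zip(bounds[:-1], bounds[1:]):
        qi = q[start:end].transpose(0, 1)  # [heads, seq_i, dim]
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        out[start:end] = F.scaled_dot_product_attention(qi, ki, vi).transpose(0, 1)
    return out

seqlens = torch.tensor([5, 12, 7])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64), seqlens.cumsum(0)])
total = int(seqlens.sum())
q = torch.randn(total, 8, 64)
k = torch.randn(total, 8, 64)
v = torch.randn(total, 8, 64)
print(varlen_mha_reference(q, k, v, cu_seqlens).shape)  # torch.Size([24, 8, 64])
```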
* Refactor Flash Attention Variable-Length MHA Example
Improve code formatting and readability in the variable-length multi-head attention example:
- Add Ruff linter ignore comment
- Enhance code style with consistent formatting
- Remove unused imports
- Improve line breaks and indentation
- Simplify function signatures and lambda expressions

Parent commit: 36a3f7b
File tree: examples/flash_attention (2 files changed, +460, -4 lines)