Commit fbd28e5

[Example] Implement FMHA Varlen Example (#131)
* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

  This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
  - `example_tilelang_gemm_splitk.py`: implements a Split-K GEMM kernel using TileLang
  - `example_tilelang_gemm_streamk.py`: implements a Stream-K GEMM kernel using TileLang

  Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing against PyTorch reference implementations (a host-side sketch of the Split-K strategy follows this list).

* Refactor GEMM SplitK and StreamK example implementations

  Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
  - Remove an unused import (Profiler) in the SplitK example
  - Simplify line breaks and improve code readability
  - Standardize indentation and remove unnecessary whitespace
  - Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

  This commit introduces comprehensive block sparse attention benchmarks for different libraries:
  - TileLang block sparse FMHA implementation
  - Triton block sparse FMHA implementation
  - PyTorch reference block sparse FMHA implementation
  - FlashAttention dense FMHA reference implementation

  The benchmarks include:
  - Configurable benchmark parameters (batch size, heads, sequence length, etc.)
  - Sparse mask generation using top-k and threshold methods (sketched after this list)
  - Performance measurement for different sparse attention configurations
  - Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements
  - Add Ruff linter ignore comments to benchmark files
  - Improve code formatting and line breaks
  - Remove unused imports
  - Standardize print statement formatting
  - Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming
  - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in the CUDA common header
  - Rename existing atomic add functions to PascalCase (atomicAdd -> AtomicAdd)
  - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
  - Update kernel and language customization to use the new function names
  - Add return type annotations in the profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

  This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
  - Group Query Attention (GQA) implementation
  - Flash Attention forward pass
  - Performance benchmarking
  - Configurable parameters for batch, heads, sequence length, and dimension
  - Autotuning support
  - Reference implementation comparison

* Refactor IR lowering pipeline into modular phases

  This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
  - `LowerAndLegalize`: handles initial IR legalization and transformation
  - `OptimizeForTarget`: applies target-specific optimizations

  The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lint fix

* NSA kernel

* Enhance Native Sparse Attention examples with code improvements and parameter updates
  - Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
  - Increased the default number of heads and selected blocks in the TileLang NSA example
  - Added Ruff linter ignore comments to reference.py
  - Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations (sketched after this list)
  - Implement `next_power_of_2()` to calculate the next power of 2 for an integer
  - Add `cdiv()` for ceiling division of integers

* Refactor DeepSeek MLA decode example with enhanced Flash Attention implementation
  - Update the flash attention kernel to support positional embeddings (PE)
  - Modify the reference implementation to handle PE and group query attention
  - Increase the default batch size and adjust benchmarking parameters
  - Improve kernel performance and readability
  - Add einops and torch operations for more flexible tensor manipulation

* Update README.md with the corrected Flash MLA Decoding example path
  - Point the Flash MLA Decoding example link at the correct directory
  - Ensure accurate navigation to the DeepSeek MLA decoding example

* Refactor Native Sparse Attention kernel and improve utility functions

  This commit introduces several improvements:
  - Simplified the native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py
  - Enhanced error handling in loop_partition.cc with more informative error messages
  - Updated print.py to support multi-dimensional buffer printing
  - Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting
  - Relaxed the default absolute tolerance in torch comparison from 1e-3 to 1e-2
  - Added shape validation and detailed mismatch information to tensor comparison

* Refactor code formatting and improve utility functions

  This commit introduces several code formatting and utility improvements:
  - Add a Ruff linter ignore comment in example_tilelang_nsa.py
  - Enhance readability in loop_partition.cc and lower_tile_op.cc with improved line breaks
  - Simplify print_flat_buffer_with_condition in print.py
  - Refactor torch_assert_close in testing/__init__.py with improved line formatting

* Enhance buffer printing support for fragment and shared memory buffers

  This commit improves the print functionality in print.py by:
  - Adding support for printing fragment memory buffers
  - Implementing a new print_fragment_buffer_with_condition macro
  - Extending print_shared_buffer_with_condition for shared memory buffers
  - Updating the generic print function to handle different buffer scopes

* Resolve merge conflict in print.py

  Remove the merge conflict marker and clean up whitespace in the print module.

* Add Variable-Length Multi-Head Attention (MHA) example with Flash Attention support

  Introduce a new example script `example_mha_fwd_varlen.py` that demonstrates (see the varlen sketch after this list):
  - Variable-length Multi-Head Attention (MHA) implementation
  - Flash Attention forward pass with padding mask support
  - Performance benchmarking for variable-length sequences
  - Configurable parameters for batch, heads, sequence length, and dimension
  - Reference implementation comparison with PyTorch and FlashAttention

* Refactor Flash Attention variable-length MHA example

  Improve code formatting and readability in the variable-length multi-head attention example:
  - Add a Ruff linter ignore comment
  - Enhance code style with consistent formatting
  - Remove unused imports
  - Improve line breaks and indentation
  - Simplify function signatures and lambda expressions
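The Split-K and Stream-K example scripts themselves are not part of this page's diff, so the following is only a minimal host-side sketch of the Split-K idea named above: partition the reduction (K) dimension into chunks, compute one partial product per chunk, and accumulate the partials; on the GPU the partials come from independent thread blocks and are combined with atomic adds. The function name and signature here are hypothetical.

```python
import torch

def splitk_gemm_reference(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    """Host-side emulation of a Split-K GEMM: the K (reduction) dimension is
    split into `split_k` chunks, each chunk yields a partial [M, N] product,
    and the partials are summed (the kernel does this with atomic adds)."""
    M, K = A.shape
    _, N = B.shape
    assert K % split_k == 0, "sketch assumes K divides evenly"
    chunk = K // split_k
    C = torch.zeros(M, N, dtype=torch.float32)
    for s in range(split_k):
        # Each iteration models one independent partial GEMM over a K-chunk.
        C += A[:, s * chunk:(s + 1) * chunk] @ B[s * chunk:(s + 1) * chunk, :]
    return C

A, B = torch.randn(128, 256), torch.randn(256, 64)
torch.testing.assert_close(splitk_gemm_reference(A, B), A @ B, rtol=1e-4, atol=1e-4)
```

Stream-K differs in how it load-balances the same partial-reduction work across SMs; the accumulation step is analogous.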
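The benchmark files are also outside this diff. As a sketch of the two mask-generation methods the message mentions (top-k and threshold), assuming a block-level importance score tensor; every name below is hypothetical:

```python
import torch

def topk_block_mask(scores: torch.Tensor, topk: int) -> torch.Tensor:
    """Keep the `topk` highest-scoring KV blocks per query block.
    scores: [batch, heads, num_q_blocks, num_kv_blocks] block importances."""
    idx = torch.topk(scores, topk, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask

def threshold_block_mask(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """Keep every KV block whose normalized importance exceeds `threshold`."""
    return scores.softmax(dim=-1) > threshold

scores = torch.randn(1, 8, 32, 32)        # hypothetical block importances
mask = topk_block_mask(scores, topk=8)    # boolean mask for the sparse kernels
print(mask.float().mean().item())         # fraction of blocks kept, here 0.25
```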
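The two integer utilities are specified closely enough to sketch. The exact bodies in the repository are not shown here, but a common formulation is:

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, e.g. next_power_of_2(5) == 8."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    """Ceiling division: how many size-`b` tiles cover `a` elements."""
    return (a + b - 1) // b

assert next_power_of_2(1) == 1 and next_power_of_2(5) == 8
assert cdiv(4096, 128) == 32 and cdiv(100, 64) == 2
```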
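Finally, the new `example_mha_fwd_varlen.py` presumably accounts for most of the +460 added lines, but its contents are not reproduced on this page. For orientation, here is a minimal reference for what "variable-length" means in this setting, in the style of FlashAttention's varlen API: ragged sequences are packed into a single [total_tokens, heads, dim] tensor and delimited by cumulative sequence lengths. The helper below is a hypothetical sketch, not the example's actual code.

```python
import torch
import torch.nn.functional as F

def varlen_mha_reference(q, k, v, cu_seqlens):
    """q, k, v: [total_tokens, heads, dim] packed sequences; cu_seqlens holds
    cumulative boundaries, so sequence i spans cu_seqlens[i]:cu_seqlens[i+1].
    Each sequence attends only to its own tokens (the padding-free layout)."""
    outs = []
    for i in range(len(cu_seqlens) - 1):
        s, e = cu_seqlens[i], cu_seqlens[i + 1]
        # [len_i, H, D] -> [H, len_i, D] so heads act as the batch dimension
        qi, ki, vi = (t[s:e].transpose(0, 1) for t in (q, k, v))
        oi = F.scaled_dot_product_attention(qi, ki, vi)
        outs.append(oi.transpose(0, 1))
    return torch.cat(outs, dim=0)

heads, dim = 8, 64
cu_seqlens = [0, 37, 100, 128]                    # three packed sequences
q = k = v = torch.randn(cu_seqlens[-1], heads, dim)
out = varlen_mha_reference(q, k, v, cu_seqlens)   # [128, 8, 64]
```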
1 parent 36a3f7b commit fbd28e5

2 files changed: +460 −4 lines

examples/flash_attention/example_gqa_fwd_bshd.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -204,13 +204,13 @@ def ref_program(Q, K, V, is_causal, groups=1):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
+    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--heads', type=int, default=64, help='heads')
+    parser.add_argument('--seq_len', type=int, default=256, help='sequence length')
     parser.add_argument('--dim', type=int, default=128, help='dim')
     parser.add_argument('--is_causal', action='store_true', help='causal')
     parser.add_argument('--tune', action='store_true', help='tune configs')
-    parser.add_argument('--groups', type=int, default=8, help='groups')
+    parser.add_argument('--groups', type=int, default=16, help='groups')
     args = parser.parse_args()
     batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
     flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
```
