-
Notifications
You must be signed in to change notification settings - Fork 332
[Example] Implememt FMHA Varlen Example #131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques: - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts: - Remove unused import (Profiler) in splitk example - Simplify line breaks and improve code readability - Standardize indentation and remove unnecessary whitespace - Optimize atomic add and copy operations for better clarity
This commit introduces comprehensive block sparse attention benchmarks for different libraries: - TileLang block sparse FMHA implementation - Triton block sparse FMHA implementation - PyTorch reference block sparse FMHA implementation - FlashAttention dense FMHA reference implementation The benchmarks include: - Configurable benchmark parameters (batch size, heads, sequence length, etc.) - Sparse mask generation using top-k and threshold methods - Performance measurement for different sparse attention configurations - Utility functions for mask generation and benchmarking
- Add Ruff linter ignore comments to benchmark files - Improve code formatting and line breaks - Remove unused imports - Standardize print statement formatting - Enhance code readability across multiple library benchmarks
- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header - Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd) - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values - Update kernel and language customization to use new function names - Add return type annotations in profiler module
…Attention in TileLang This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates: - Group Query Attention (GQA) implementation - Flash Attention forward pass - Performance benchmarking - Configurable parameters for batch, heads, sequence length, and dimension - Autotuning support - Reference implementation comparison
This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases: - `LowerAndLegalize`: Handles initial IR legalization and transformation - `OptimizeForTarget`: Applies target-specific optimizations The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
…arameter Updates - Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements - Increased default number of heads and selected blocks in TileLang NSA example - Added Ruff linter ignore comments to reference.py - Standardized function signatures and improved code readability across NSA implementations
- Implement `next_power_of_2()` to calculate the next power of 2 for an integer - Add `cdiv()` function for ceiling division of integers
- Implement `next_power_of_2()` to calculate the next power of 2 for an integer - Add `cdiv()` function for ceiling division of integers
…plementation - Update flash attention kernel to support positional embeddings (PE) - Modify reference implementation to handle PE and group query attention - Increase default batch size and adjust benchmarking parameters - Improve kernel performance and readability - Add einops and torch operations for more flexible tensor manipulation
- Modify the example link for Flash MLA Decoding to point to the correct directory - Ensure accurate navigation to the DeepSeek MLA decoding example
This commit introduces several improvements: - Simplified native sparse attention kernel by inlining macro functions in example_tilelang_nsa.py - Enhanced error handling in loop_partition.cc with more informative error messages - Updated print.py to support multi-dimensional buffer printing - Improved torch_assert_close in testing/__init__.py with more detailed mismatch reporting - Reduced default absolute tolerance in torch comparison from 1e-3 to 1e-2 - Added shape validation and detailed mismatch information in tensor comparison
This commit introduces several code formatting and utility improvements: - Add Ruff linter ignore comment in example_tilelang_nsa.py - Enhance code readability in loop_partition.cc and lower_tile_op.cc with improved line breaks - Simplify print_flat_buffer_with_condition in print.py - Refactor torch_assert_close in testing/__init__.py with improved line formatting
This commit improves the print functionality in print.py by: - Adding support for printing fragment memory buffers - Implementing a new print_fragment_buffer_with_condition macro - Extending print_shared_buffer_with_condition for shared memory buffers - Updating the generic print function to handle different buffer scopes
Remove merge conflict marker and clean up whitespace in the print module
…ention Support Introduce a new example script `example_mha_fwd_varlen.py` that demonstrates: - Variable-length Multi-Head Attention (MHA) implementation - Flash Attention forward pass with padding mask support - Performance benchmarking for variable-length sequences - Configurable parameters for batch, heads, sequence length, and dimension - Reference implementation comparison with PyTorch and FlashAttention
Improve code formatting and readability in the variable-length multi-head attention example: - Add Ruff linter ignore comment - Enhance code style with consistent formatting - Remove unused imports - Improve line breaks and indentation - Simplify function signatures and lambda expressions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This pull request includes several changes to
examples/flash_attention/example_gqa_fwd_bshd.pyandtilelang/language/print.pyto improve functionality and enhance the codebase. The most important changes include updating argument defaults in the example script and adding new functions for printing buffers with conditions.Updates to argument defaults:
examples/flash_attention/example_gqa_fwd_bshd.py: Updated the default values forbatch,heads,seq_len, andgroupsarguments to better suit the new testing requirements.Enhancements to buffer printing:
tilelang/language/print.py: Added new importscopyandalloc_sharedto support shared memory operations.tilelang/language/print.py: Renamedprint_flat_buffer_with_conditiontoprint_shared_buffer_with_conditionto reflect its purpose more accurately.tilelang/language/print.py: Introduced a new macroprint_fragment_buffer_with_conditionto conditionally print the values of a flattened TIR buffer when the condition is true.tilelang/language/print.py: Enhanced theprintfunction to support printing fragment buffers by using the newprint_fragment_buffer_with_conditionmacro. [1] [2]