[Example] Optimize warp specialize flashmla example #698
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on optimizing the flashmla example by aligning its kernel scheduling with the official flashmla repository. The changes primarily involve refining shared memory allocations, streamlining synchronization barriers, and adjusting GEMM call policies to improve performance and resource utilization within the GPU kernel.
Highlights
- **Memory Optimization & Data Flow:** Refactored shared memory usage by replacing the generic `S_shared` and `S_shared_` with the more specific `SP0_shared` and `SP1_shared` (the latter reusing `Q_pe_shared`). Introduced explicit fragment copies (`Q_pe_local_0/1`, `acc_s_0/1_cast`) to optimize data movement and access patterns within the kernel.
- **Kernel Performance Tuning:** Removed the `policy=T.GemmWarpPolicy.FullCol` argument from several `T.gemm` calls, indicating a potential default or optimized internal handling of GEMM operations.
- **Synchronization Simplification:** Streamlined barrier usage by removing several `k_pe_shared_X_free_barrier` and `k_shared_X_l_free_barrier` calls, which simplifies synchronization logic and potentially reduces overhead.
- **Code Readability:** Added descriptive comments to shared memory allocations and barrier definitions, improving the clarity and maintainability of the kernel code.
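The `SP0_shared`/`SP1_shared` split described above is a ping-pong buffering pattern. A minimal sketch in plain Python (buffer names taken from the diff, the loop and tile count are hypothetical, and no GPU semantics are modeled) of how two buffers alternate across tile iterations so that filling one can overlap compute on the other:

```python
# Hypothetical illustration of ping-pong buffering: consecutive tile
# iterations alternate between two shared-memory buffers.
buffers = ["SP0_shared", "SP1_shared"]
schedule = [(k, buffers[k % 2]) for k in range(4)]  # four example tiles
print(schedule)
```

With two buffers, iteration `k` writes buffer `k % 2` while the previous iteration's buffer is still being consumed, which is why the barrier structure around these buffers matters.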
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Code Review
The code changes optimize the warp specialize flashmla example by aligning the schedule from the official flashmla repo. A critical issue was identified concerning unsafe memory reuse that could lead to memory corruption under different parameter settings. A detailed explanation and a suggested fix have been provided.
```python
SP0_shared = T.alloc_shared([block_H, block_N], dtype)
# smem_sP1 reuse Q_pe_shared
SP1_shared = Q_pe_shared
```
This memory reuse optimization is unsafe and can lead to memory corruption. `SP1_shared` reuses the memory allocated for `Q_pe_shared`.

- `Q_pe_shared` is allocated with size `block_H * pe_dim`.
- `SP1_shared` is used to store data of size `block_H * block_N` (via `T.copy(acc_s_1_cast, SP1_shared)` on line 316, where `acc_s_1_cast` has shape `[block_H, block_N]`).

If `pe_dim < block_N`, this will result in an out-of-bounds write to shared memory. The default parameters in `main()` (`pe_dim=64`, `BLOCK_N=64`) happen to match, hiding this bug. This creates a fragile dependency between `pe_dim` and `BLOCK_N`.
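The size condition behind this warning can be checked with simple arithmetic. A sketch in plain Python (the helper name is hypothetical; the buffer sizes follow the allocations quoted above):

```python
def sp1_reuse_is_safe(block_H, pe_dim, block_N):
    # Q_pe_shared provides block_H * pe_dim elements; the aliased
    # SP1_shared is written with block_H * block_N elements. The alias
    # is only in-bounds when the backing buffer is at least as large.
    return block_H * pe_dim >= block_H * block_N

print(sp1_reuse_is_safe(64, 64, 64))  # defaults from main(): reuse fits
print(sp1_reuse_is_safe(64, 32, 64))  # pe_dim < block_N: out-of-bounds write
```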
To fix this, you could:

- (Recommended) Add an `assert pe_dim == block_N` at the beginning of the `flashattn` function to make this constraint explicit.
- (Alternative) Allocate `SP1_shared` separately to ensure it has the correct size, as suggested below. This is safer but will increase shared memory usage.

```python
SP1_shared = T.alloc_shared([block_H, block_N], dtype)
```

Commits in this pull request:

* [Enhancement] Disable cache and append git commit ID to version in tilelang (tile-ai#688)
  * Disabled caching in quickstart example for improved performance.
  * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities.
* revert quickstart
* optimize code.
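The recommended guard can be sketched in plain Python (the function signature is abbreviated and hypothetical; only the assertion reflects the review's suggestion, and kernel construction is elided):

```python
def flashattn(block_H, block_N, pe_dim):
    # SP1_shared aliases Q_pe_shared ([block_H, pe_dim]) but is written
    # with [block_H, block_N] elements; the alias is only valid when the
    # trailing extents match, so fail loudly otherwise.
    assert pe_dim == block_N, (
        f"SP1_shared reuses Q_pe_shared: pe_dim ({pe_dim}) must equal "
        f"block_N ({block_N})")
    # ... kernel construction elided ...
    return True

flashattn(block_H=64, block_N=64, pe_dim=64)  # default parameters: passes
```

An explicit assertion turns a silent out-of-bounds write into an immediate, well-described failure when someone tunes `BLOCK_N` independently of `pe_dim`.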
As titled: align the schedule with the official flashmla repo.