Conversation

@LeiWang1999 (Member) commented:

as title, align the schedule from official flashmla repo.

@LeiWang1999 merged commit a1149ca into tile-ai:main on Aug 6, 2025
2 of 3 checks passed

github-actions bot commented Aug 6, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist bot (Contributor) left a comment:

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the flashmla example by aligning its kernel scheduling with the official flashmla repository. The changes primarily involve refining shared memory allocations, streamlining synchronization barriers, and adjusting GEMM call policies to improve performance and resource utilization within the GPU kernel.

Highlights

  • Memory Optimization & Data Flow: Refactored shared memory usage by replacing the generic S_shared and S_shared_ buffers with the more specific SP0_shared and SP1_shared (the latter reusing Q_pe_shared's storage), and introduced explicit fragment copies (Q_pe_local_0/1, acc_s_0/1_cast) to optimize data movement and access patterns within the kernel; see the sketch after this list.
  • Kernel Performance Tuning: Removed the policy=T.GemmWarpPolicy.FullCol argument from several T.gemm calls, deferring to the default (or internally optimized) warp policy for those GEMM operations.
  • Synchronization Simplification: Removed several k_pe_shared_X_free_barrier and k_shared_X_l_free_barrier calls, simplifying the synchronization logic and potentially reducing overhead.
  • Code Readability: Added descriptive comments to the shared memory allocations and barrier definitions, improving the clarity and maintainability of the kernel code.
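
A minimal sketch of the reuse pattern from the first highlight, assuming TileLang's documented primitives (T.alloc_shared, T.alloc_fragment, T.copy). The buffer names and the 64-element defaults mirror the review snippet further down; the kernel scaffolding (sketch, Q_pe) is hypothetical, not the actual flashmla kernel.

import tilelang.language as T

block_H, block_N, pe_dim = 64, 64, 64  # defaults from main(), per the review below
dtype, accum_dtype = "float16", "float"

@T.prim_func
def sketch(Q_pe: T.Buffer((block_H, pe_dim), dtype)):
    with T.Kernel(1, threads=128) as bx:
        # Dedicated smem tile for the first P operand.
        SP0_shared = T.alloc_shared([block_H, block_N], dtype)
        # The Q_pe tile is consumed early, so its smem is recycled for P1.
        Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
        SP1_shared = Q_pe_shared  # alias: only safe when pe_dim == block_N

        T.copy(Q_pe, Q_pe_shared)  # load the positional-encoding tile once

        # Explicit fragment copy that casts the fp32 accumulator down to
        # fp16 before it is staged back into shared memory.
        acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
        acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
        T.copy(acc_s_1, acc_s_1_cast)
        T.copy(acc_s_1_cast, SP1_shared)  # block_H x block_N written into the alias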
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its suggestions are incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment:

Code Review

The code changes optimize the warp-specialized flashmla example by aligning its schedule with the official flashmla repo. A critical issue was identified concerning unsafe shared-memory reuse that could lead to memory corruption under different parameter settings; a detailed explanation and a suggested fix are provided below.

SP0_shared = T.alloc_shared([block_H, block_N], dtype)

# smem_sP1 reuses Q_pe_shared
SP1_shared = Q_pe_shared

critical

This memory reuse optimization is unsafe and can lead to memory corruption. SP1_shared reuses the memory allocated for Q_pe_shared.

  • Q_pe_shared is allocated with size block_H * pe_dim.
  • SP1_shared is used to store data of size block_H * block_N (via T.copy(acc_s_1_cast, SP1_shared) on line 316, where acc_s_1_cast has shape [block_H, block_N]).

If pe_dim < block_N, this will result in an out-of-bounds write to shared memory. The default parameters in main() (pe_dim=64, BLOCK_N=64) happen to match, hiding this bug. This creates a fragile dependency between pe_dim and BLOCK_N.

To fix this, you could:

  1. (Recommended) Add an assert pe_dim == block_N at the beginning of the flashattn function to make this constraint explicit.
  2. (Alternative) Allocate SP1_shared separately to ensure it has the correct size, as suggested below. This is safer but will increase shared memory usage.
SP1_shared = T.alloc_shared([block_H, block_N], dtype)
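
For completeness, a sketch of the recommended option 1, assuming flashattn takes pe_dim and block_N as Python-level parameters (this signature is hypothetical and abbreviated, not the example's actual one):

def flashattn(batch, heads, seqlen_kv, dim, pe_dim, block_N, block_H):
    # SP1_shared aliases Q_pe_shared (block_H x pe_dim) but is written with
    # block_H x block_N elements, so the two widths must match exactly.
    assert pe_dim == block_N, (
        f"SP1_shared reuses Q_pe_shared, which requires pe_dim == block_N "
        f"(got pe_dim={pe_dim}, block_N={block_N})")
    ...  # kernel construction as before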

RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [Enhancement] Disable cache and append git commit ID to version in tilelang (tile-ai#688)

* Disabled caching in quickstart example for improved performance.
* Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities.

* revert quickstart

* optimize code.
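
The commit message above describes appending the current git commit ID to the version string. Below is a minimal sketch of one way to do that in Python; the helper name, the base version, and the "+<hash>" separator are assumptions, not the actual tilelang implementation.

import subprocess

def get_git_commit_id():
    # Best-effort lookup of the short hash of HEAD; returns None when git
    # is unavailable or the source tree is not a repository.
    try:
        out = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            stderr=subprocess.DEVNULL)
        return out.decode().strip()
    except (OSError, subprocess.CalledProcessError):
        return None

__version__ = "0.1.0"  # hypothetical base version
commit = get_git_commit_id()
if commit and commit not in __version__:
    __version__ += "+" + commit  # e.g. "0.1.0+a1149ca"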