-
Notifications
You must be signed in to change notification settings - Fork 331
[Refactor] Introduce GemmInst for different targets handling #688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… and target instruction handling - Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture. - Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic. - Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture. - Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've refactored the GEMM (General Matrix Multiply) operation handling to introduce a more robust and flexible mechanism for selecting target-specific GEMM instructions. This change centralizes the logic for determining the appropriate instruction set based on the target architecture and block size, moving away from scattered boolean flags. The core impact is improved clarity and maintainability of the codebase, particularly in how warp partitioning, layout inference, and lowering processes adapt to different GPU architectures.
Highlights
- Centralized GEMM Instruction Selection: I introduced a new
GemmInstenum and aGetGemmInstmethod to centralize the logic for determining the appropriate GEMM instruction (e.g., MMA, WGMMA, MFMA) based on the target architecture and block size. - Enhanced Warp Partitioning Logic: I updated the
ComputeWarpPartitionmethod to directly accept the newGemmInsttype, making its behavior more explicit and flexible across different GEMM instruction sets. - New Warp Size Utility: I added a new utility function,
TargetGetWarpSize, to consistently retrieve the warp size for a given target architecture, simplifying warp size calculations throughout the codebase. - Streamlined Layout Inference and Lowering: I refactored the
LowerandInferLayoutmethods to leverage the newGetGemmInstandTargetGetWarpSizeutilities, which significantly simplifies the conditional logic for handling different target-specific GEMM implementations and improves code maintainability.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the GEMM operator to better handle different target architectures by introducing a GemmInst enum and centralizing logic for determining GEMM instructions and warp sizes. The changes improve code clarity and maintainability. My review has identified two critical issues where the refactoring has led to incorrect logic: one in determining whether to use WGMMA, and another in calculating warp partitions. These issues could lead to incorrect code generation or runtime errors and should be addressed.
src/op/gemm.cc
Outdated
| bool allow_wgmma = | ||
| TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The refactoring to introduce GetGemmInst seems to have missed the CheckWGMMA() call. Previously, the decision to use WGMMA depended on CheckWGMMA(), which verifies data types and transpose conditions. Without this check, the compiler might incorrectly select WGMMA for unsupported configurations, leading to compilation errors or incorrect results.
| bool allow_wgmma = | |
| TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0); | |
| bool allow_wgmma = | |
| TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && CheckWGMMA(); |
src/op/gemm.cc
Outdated
| auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target); | ||
| int warp_size = TargetGetWarpSize(T.target); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ComputeWarpPartition function expects num_warps as its first argument, which should be block_size / warp_size. However, block_size is being passed instead. This will lead to incorrect warp partitioning. Additionally, warp_size is computed after it's needed for this call. Please reorder the statements and pass the correct argument.
| auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target); | |
| int warp_size = TargetGetWarpSize(T.target); | |
| int warp_size = TargetGetWarpSize(T.target); | |
| auto [warp_m, warp_n] = ComputeWarpPartition(block_size / warp_size, gemm_inst, T.target); |
…lelang (tile-ai#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities.
* [Enhancement] Disable cache and append git commit ID to version in tilelang (#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities. * revert quickstart
* [Enhancement] Disable cache and append git commit ID to version in tilelang (#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities. * revert quickstart * optimize code.
…#688) * [Enhancement] Refactor GEMM operations for improved warp partitioning and target instruction handling - Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture. - Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic. - Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture. - Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase. * bug fix * test fix * lint fix
…e-ai#697) * [Enhancement] Disable cache and append git commit ID to version in tilelang (tile-ai#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities. * revert quickstart
* [Enhancement] Disable cache and append git commit ID to version in tilelang (tile-ai#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities. * revert quickstart * optimize code.
GetGemmInstmethod to determine the appropriate GEMM instruction based on block size and target architecture.ComputeWarpPartitionto accept the GEMM instruction type, enhancing flexibility in warp partitioning logic.TargetGetWarpSizeutility to streamline warp size retrieval based on target architecture.