[Bugfix][CI] Bug fixing and migrate CI from ada to hopper #652
Summary of Changes
Hello @xwhzz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses several CI stability issues specific to the Hopper architecture, primarily focusing on improving the robustness of Tensor Memory Accelerator (TMA) operations and refining the conditions under which Warp Group Matrix Multiply Accumulate (WGMMA) is applied. The changes aim to prevent internal errors, resolve numerical mismatches, and ensure correct code generation for Hopper-specific optimizations, contributing to a more stable CI pipeline.
Highlights
- TMA Robustness: Implemented stricter checks and graceful fallbacks for Tensor Memory Accelerator (TMA) bulk copy operations in `src/op/bulk_copy.cc` to prevent errors related to unsupported strides and layouts. This includes checks for global stride alignment and size, and replacing hard `ICHECK(0)` failures with `return Stmt()` for better error handling.
- WGMMA Precision Control: Introduced a `CheckWGMMA` function in `src/op/gemm.cc` and `src/op/gemm.h` to precisely determine when Warp Group Matrix Multiply Accumulate (WGMMA) can be utilized. This function evaluates input/output data types (e.g., FP16, BF16, FP32, NVFloat8, Int8) and K-dimension alignment, ensuring numerical correctness and proper application of Hopper-specific optimizations. The `Gemm::Lower` and `Gemm::InferLayout` methods now leverage this check.
- Dynamic TMA Lowering: Added logic in `src/transform/lower_tile_op.cc` to dynamically set the `kDisableTMALower` context configuration based on whether TMA operations (`tma_load`, `tma_store`, etc.) are actually present in the code. This allows for more optimized compilation by only running TMA-specific lowering passes when necessary.
- Buffer Access Accuracy: Enhanced buffer access region analysis in `src/transform/warp_specialized_rewriter.cc` to correctly handle remapped buffers. The `WSCodeEmitter` now checks the `buffer_data_to_buffer_` map when populating read and write sets, improving the accuracy of buffer tracking during warp specialization.
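The "Dynamic TMA Lowering" highlight can be illustrated with a minimal, self-contained sketch. It is not the code in `src/transform/lower_tile_op.cc`: `PassConfig`, `ContainsTmaOp`, and `ConfigureTmaLowering` are hypothetical names, and the kernel body is modelled as a flat list of op names rather than a TVM IR tree. Only `kDisableTMALower`, `tma_load`, and `tma_store` come from the summary.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for the pass context; the flag models the
// kDisableTMALower setting named in the summary.
struct PassConfig {
  bool disable_tma_lower = false;
};

// Returns true if any op in the (flattened) kernel body is a TMA op.
// Only tma_load and tma_store are listed; the summary's "etc." is not
// expanded here.
bool ContainsTmaOp(const std::vector<std::string>& op_names) {
  static const std::unordered_set<std::string> kTmaOps = {"tma_load",
                                                          "tma_store"};
  for (const auto& name : op_names) {
    if (kTmaOps.count(name) != 0) {
      return true;
    }
  }
  return false;
}

// Disables the TMA-specific lowering passes when no TMA op is present.
PassConfig ConfigureTmaLowering(const std::vector<std::string>& op_names) {
  PassConfig config;
  config.disable_tma_lower = !ContainsTmaOp(op_names);
  return config;
}
```

Under these assumptions, a kernel containing only non-TMA ops gets `disable_tma_lower == true`, while one containing `tma_load` keeps the TMA lowering passes enabled.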
Code Review
This pull request introduces bug fixes for Hopper architecture support, focusing on TMA and WGMMA operations. The changes add checks to prevent crashes and fallback mechanisms. The review focuses on correctness, maintainability, and C++ best practices, addressing potential null pointer dereferences, code complexity, and duplication.
bool Gemm::CheckWGMMA() const {
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::NVFloat8E4M3() &&
             B->dtype == DataType::NVFloat8E4M3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E4M3() &&
             B->dtype == DataType::NVFloat8E5M2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E5M2() &&
             B->dtype == DataType::NVFloat8E4M3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E5M2() &&
             B->dtype == DataType::NVFloat8E5M2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::BFloat(16) &&
             B->dtype == DataType::BFloat(16))
      return K % 16 == 0;
    else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
      return (!trans_A) && trans_B && K % 8 == 0;
    else if (A->dtype == DataType::NVFloat8E4M3() &&
             B->dtype == DataType::NVFloat8E4M3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E4M3() &&
             B->dtype == DataType::NVFloat8E5M2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E5M2() &&
             B->dtype == DataType::NVFloat8E4M3())
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::NVFloat8E5M2() &&
             B->dtype == DataType::NVFloat8E5M2())
      return (!trans_A) && trans_B && K % 32 == 0;
    else
      return false;
  } else if (C->dtype == DataType::Int(32)) {
    if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
      return (!trans_A) && trans_B && K % 32 == 0;
    else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
      return (!trans_A) && trans_B && K % 32 == 0;
The `CheckWGMMA` function contains many repetitive conditions, which makes it hard to read and maintain. Consider refactoring it to group conditions by data type and factor out the common logic.
bool Gemm::CheckWGMMA() const {
  auto is_fp8 = [](const DataType& dtype) {
    return dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2();
  };
  auto is_int8 = [](const DataType& dtype) {
    return dtype == DataType::Int(8) || dtype == DataType::UInt(8);
  };
  if (C->dtype == DataType::Float(16)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) return K % 16 == 0;
    if (is_fp8(A->dtype) && is_fp8(B->dtype)) return !trans_A && trans_B && K % 32 == 0;
  } else if (C->dtype == DataType::Float(32)) {
    if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) return K % 16 == 0;
    if (A->dtype == DataType::BFloat(16) && B->dtype == DataType::BFloat(16)) return K % 16 == 0;
    if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32)) return !trans_A && trans_B && K % 8 == 0;
    if (is_fp8(A->dtype) && is_fp8(B->dtype)) return !trans_A && trans_B && K % 32 == 0;
  } else if (C->dtype == DataType::Int(32)) {
    if (is_int8(A->dtype) && is_int8(B->dtype)) return !trans_A && trans_B && K % 32 == 0;
  }
  return false;
}
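To try the rule table outside of TVM, the same conditions can be modelled in a standalone sketch. Everything here is an assumption made for illustration: `DType`, `GemmShape`, and `CheckWgmma` are hypothetical stand-ins for `DataType` and the `Gemm` members, and only the dtype combinations and K-alignment values are taken from the snippet above.

```cpp
#include <cstdint>

// Illustrative stand-ins for the TVM DataType values used in the PR.
enum class DType {
  kFloat16, kBFloat16, kFloat32, kFloat8E4M3, kFloat8E5M2, kInt8, kUInt8, kInt32
};

// Hypothetical plain-data model of the Gemm fields read by CheckWGMMA.
struct GemmShape {
  DType a, b, c;
  bool trans_a, trans_b;
  std::int64_t k;
};

bool CheckWgmma(const GemmShape& g) {
  auto is_fp8 = [](DType t) {
    return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
  };
  auto is_int8 = [](DType t) {
    return t == DType::kInt8 || t == DType::kUInt8;
  };
  auto both = [&](DType t) { return g.a == t && g.b == t; };
  // Several cases additionally require A non-transposed and B transposed.
  const bool tn_layout = !g.trans_a && g.trans_b;

  if (g.c == DType::kFloat16) {
    if (both(DType::kFloat16)) return g.k % 16 == 0;
    if (is_fp8(g.a) && is_fp8(g.b)) return tn_layout && g.k % 32 == 0;
  } else if (g.c == DType::kFloat32) {
    if (both(DType::kFloat16) || both(DType::kBFloat16)) return g.k % 16 == 0;
    if (both(DType::kFloat32)) return tn_layout && g.k % 8 == 0;
    if (is_fp8(g.a) && is_fp8(g.b)) return tn_layout && g.k % 32 == 0;
  } else if (g.c == DType::kInt32) {
    if (is_int8(g.a) && is_int8(g.b)) return tn_layout && g.k % 32 == 0;
  }
  return false;  // Any other combination does not use WGMMA.
}
```

For example, an FP16 x FP16 GEMM accumulating in FP32 with `K = 64` passes regardless of transposition, while an FP8 GEMM with `K = 48` is rejected because 48 is not a multiple of 32.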
src/op/bulk_copy.cc (Outdated)
for (size_t i{1}; i < desc.global_stride.size(); i++) {
  unsigned long long stride = desc.global_stride[i].as<IntImmNode>()->value;
  if (stride % 16 != 0 || stride >= (1ULL << 40)) {
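The hunk above dereferences the result of `as<IntImmNode>()` directly, which would be a null pointer for a symbolic stride. The later commit notes say a null check was added for that case. A minimal standalone sketch of that validation follows, assuming a stride is modelled as `std::optional<std::uint64_t>` where an empty value stands for a symbolic (non-constant) stride. `StridesSupportTma` is a hypothetical name; the real code operates on TVM `PrimExpr` values and returns `Stmt()` to fall back.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Returns true if every stride after the innermost one is a known constant
// that is a multiple of 16 and below 2^40, mirroring the condition in the
// hunk. std::nullopt models a symbolic stride, which is rejected instead of
// being dereferenced.
bool StridesSupportTma(
    const std::vector<std::optional<std::uint64_t>>& global_stride) {
  // Index 0 is skipped, as in the original loop (size_t i{1}).
  for (std::size_t i = 1; i < global_stride.size(); ++i) {
    if (!global_stride[i].has_value()) {
      return false;  // Symbolic stride: caller should fall back.
    }
    const std::uint64_t stride = *global_stride[i];
    if (stride % 16 != 0 || stride >= (1ULL << 40)) {
      return false;  // Misaligned or too large for TMA.
    }
  }
  return true;
}
```

A caller would use a `false` result to skip TMA lowering for that copy and take the non-TMA path rather than abort compilation.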
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE) &&
    inner_box_dim_ % 256 != 0)
Consider adding braces {} to the if statement for better readability and to avoid potential dangling-else issues.
Suggested change:
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE) && inner_box_dim_ % 256 != 0) {
  return Stmt();
}
int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes();
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE) &&
The magic number 256 should be replaced with a named constant to improve readability and maintainability.
Suggested change:
constexpr int kInnerBoxDimThreshold = 256;
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE) && inner_box_dim_ % kInnerBoxDimThreshold != 0)
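The two suggestions (braces and a named constant) can be combined. The sketch below shows one way this might look in isolation; `Swizzle` is a hypothetical stand-in for the CUDA `CUtensorMapSwizzle` values, and `MustFallBackFromTma` is an illustrative helper, not a function in the PR. The fallback itself (`return Stmt()`) is represented by a boolean.

```cpp
// Hypothetical stand-in for the CUDA driver swizzle modes; only the "none"
// mode matters for this check.
enum class Swizzle { kNone, k32B, k64B, k128B };

// Named constant replacing the magic number 256, as the review suggests.
constexpr int kInnerBoxDimThreshold = 256;

// Returns true when the copy should fall back from TMA: no swizzle is in
// use and the inner box size in bytes is not a multiple of the threshold.
bool MustFallBackFromTma(Swizzle swizzle, int instruction_dim,
                         int dtype_bytes) {
  const int inner_box_dim = instruction_dim * dtype_bytes;
  if (swizzle == Swizzle::kNone &&
      inner_box_dim % kInnerBoxDimThreshold != 0) {
    return true;
  }
  return false;
}
```

With 2-byte elements, an instruction dimension of 64 gives 128 bytes and triggers the fallback without swizzling, while 128 elements (256 bytes) does not.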
for (auto region : access[0]) {
  auto var = region->buffer->data;
  if (buffer_data_to_buffer_.count(var)) {
    read_set.insert(buffer_data_to_buffer_[var].get());
  } else {
    read_set.insert(region->buffer.get());
  }
}
for (auto region : access[1]) {
  auto var = region->buffer->data;
  if (buffer_data_to_buffer_.count(var)) {
    write_set.insert(buffer_data_to_buffer_[var].get());
  } else {
    write_set.insert(region->buffer.get());
  }
}
The logic for populating `read_set` and `write_set` is duplicated. Extract the common logic into a helper function or lambda to improve maintainability.
Suggested change:
auto populate_access_set = [&](std::set<const BufferNode*>& set, const Array<BufferRegion>& regions) {
  for (auto region : regions) {
    auto it = buffer_data_to_buffer_.find(region->buffer->data);
    if (it != buffer_data_to_buffer_.end()) {
      set.insert(it->second.get());
    } else {
      set.insert(region->buffer.get());
    }
  }
};
populate_access_set(read_set, access[0]);
populate_access_set(write_set, access[1]);
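The remapping behaviour of this helper can be checked in isolation with standard containers. This is a sketch under assumptions: `Buffer`, `RemapTable`, and `PopulateAccessSet` are hypothetical stand-ins for the TVM `BufferNode`, the `buffer_data_to_buffer_` map, and the suggested lambda, and the buffer's data variable is modelled as an integer id.

```cpp
#include <set>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a TVM buffer; data_var models buffer->data.
struct Buffer {
  int data_var;
};

// Models buffer_data_to_buffer_: data variable -> canonical buffer.
using RemapTable = std::unordered_map<int, const Buffer*>;

// Inserts the canonical buffer for each accessed region: the remapped buffer
// if the data variable has an entry in the table, else the buffer itself.
void PopulateAccessSet(std::set<const Buffer*>& out,
                       const std::vector<const Buffer*>& regions,
                       const RemapTable& remap) {
  for (const Buffer* buffer : regions) {
    auto it = remap.find(buffer->data_var);
    out.insert(it != remap.end() ? it->second : buffer);
  }
}
```

Using a single `find` instead of `count` followed by `operator[]` avoids a second lookup and cannot insert a default entry into the map by accident.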
Phase out `test_tilelang_issue_101.py` as it has never been correct.
* fix CI bugs in hopper
* lint fix
* Update bulk_copy.cc
* Refactor bulk copy logic in LowerBulkCopy function
  - Removed unnecessary blank lines for improved code readability.
  - Enhanced stride validation by checking for null pointers in global stride calculations, ensuring robustness against symbolic strides.
  - Updated pass configuration handling in dynamic tile language tests to streamline dynamic alignment and TMA lower pass settings.
* test fix
* ci fix
* Update flash-attention dependencies and clean up example code
  - Downgraded `flash-attn` dependency version in `requirements-test.txt` to `<=2.2.0`.
  - Removed unused imports and commented-out code in various example files to enhance readability and maintainability.
  - Updated the `flashattn` function signature to include default parameters for `block_M`, `block_N`, `num_stages`, and `threads`.
  - Cleaned up the `example_mha_fwd_varlen.py` and `example_mha_bwd_wgmma_pipelined.py` files by removing unnecessary comments and improving code clarity.
  - Deleted the `example_mha_inference.py` file as it is no longer needed.
* Update CI workflow to remove `--user` flag from pip install commands
  - Removed the `--user` flag from the pip install commands in both the development and testing sections of the CI workflow to ensure proper installation of dependencies in the virtual environment.
* Update CI workflow to include `--no-user` flag in pip install commands
  - Added the `--no-user` flag to the pip install commands in both the development and testing sections of the CI workflow to ensure dependencies are installed correctly within the virtual environment.
* Update CI workflow to include `--no-user` flag in pip install command for wheel mode
  - Added the `--no-user` flag to the pip install command in the wheel mode section of the CI workflow to ensure dependencies are installed correctly within the virtual environment.
* test fix
* avoid conflict with system environments
* test fix
* add commnets

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
This PR is intended to fix #506 CI bugs on Hopper, including the following issues:
- `test_tilelang_issue_101.py`
- `test_tilelang_kernel_deepseek_nsa.py`
- `test_tilelang_kernel_dequantize_gemm.py`

Also fixes #479.