Low-bit kernels fix and implementation #704
Conversation
tzj-fxz commented Aug 8, 2025
- Fix a bug in the FP4 GEMM on Hopper: the exponential bias used in the conversion was incorrect (a reference sketch of the corrected decode follows below).
- Implement a BF16xMXFP4 GEMM kernel on Hopper for use by the vLLM community.
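For context, here is a minimal reference sketch of the corrected E2M1 FP4 to FP16 decode, assuming two FP4 values are packed per uint8 with pos selecting the nibble. The helper name and packing order are illustrative assumptions rather than the actual code in example_dequant_gemm_fp4_hopper.py, and FP4 subnormal codes are not special-cased here:

```python
import struct

def fp4_nibble_to_fp16(u8: int, pos: int) -> float:
    """Hedged reference: decode one E2M1 FP4 nibble into a float via FP16 bits."""
    f4 = (u8 >> (pos * 4)) & 0xF        # select one 4-bit value
    s = (f4 >> 3) & 0x1                 # 1 sign bit
    e_f4 = (f4 & 6) >> 1                # 2 exponent bits (the corrected mask/shift)
    m_f4 = f4 & 1                       # 1 mantissa bit
    e_f16 = e_f4 + 14                   # re-bias: FP16 bias 15 minus FP4 bias 1
    # sign | 5-bit exponent | 10-bit mantissa -> FP16 bit pattern
    val_f16 = ((s << 15) | (e_f16 << 10) | (m_f4 << 9)) & 0xFFFF
    return struct.unpack("<e", struct.pack("<H", val_f16))[0]

# e.g. fp4_nibble_to_fp16(0b0101, 0) == 3.0  (s=0, e=2, m=1 -> 2^(2-1) * 1.5)
```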
Summary of Changes
Hello @tzj-fxz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've addressed a critical bug in the FP4 GEMM conversion on the Hopper architecture, specifically correcting an exponential bias issue. Concurrently, I've implemented a new BF16xMXFP4 GEMM kernel tailored for Hopper, which will be beneficial for the vLLM community. These changes enhance the accuracy and expand the capabilities of low-bit precision operations within the system.
Highlights
- FP4 GEMM Conversion Bug Fix: The _tir_u8_to_f4_to_f16 and _convert functions in example_dequant_gemm_fp4_hopper.py were updated to correct the exponential bias calculation for the FP4 to FP16 conversion. Specifically, the e_f4 calculation was changed from f4 & 7 to (f4 & 6) >> 1, and e_f16 from e_f4 | 8 to e_f4 + 14. Additionally, the mantissa bits (m_f4 and m_f16) are now correctly incorporated into the val_f16 calculation, ensuring an accurate floating-point representation.
- BF16xMXFP4 GEMM Kernel Implementation: A new file, example_dequant_gemm_mxfp4_hopper.py, has been introduced. It implements a BF16xMXFP4 GEMM kernel, including a _tir_u8_to_f4_to_bf16 function for converting 4-bit unsigned integers to bfloat16 with the proper exponential bias and scaling, plus convert and convert_scale functions and a matmul kernel that performs GEMM in BF16xMXFP4 precision with optional scaling (a hedged reference decode is sketched after this list).
- New Test Case for BF16xMXFP4 GEMM: test_example_dequantize_gemm.py was modified to add a new test function, test_example_dequant_gemm_mxfp4_hopper, which validates the correctness of the newly implemented BF16xMXFP4 GEMM kernel and ensures the new low-bit precision operations behave as expected.
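For the BF16xMXFP4 path, the sketch below mirrors the dequantization snippets quoted in the review comments further down: decode an E2M1 nibble to a bfloat16 bit pattern and optionally add an MX-style shared uint8 exponent, clamped to 8 bits. The function name, tensor dtypes, and the use of torch.clamp are assumptions, not the exact contents of example_dequant_gemm_mxfp4_hopper.py:

```python
from typing import Optional

import torch

def u8_to_f4_to_bf16_ref(u8: torch.Tensor, pos: int,
                         scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hedged reference: decode packed E2M1 nibbles to bfloat16, with optional MX scale."""
    f4 = ((u8 >> (pos * 4)) & 0xF).to(torch.int32)
    s = (f4 >> 3) & 0x1                  # 1 sign bit
    e_f4 = (f4 & 6) >> 1                 # 2 exponent bits
    m_f4 = f4 & 1                        # 1 mantissa bit
    e_bf16 = e_f4 + 126                  # bias gap: 2^(8-1) - 2^(2-1) = 126
    if scale is not None:
        # the MX scale is a shared exponent; clamp so the result stays within 8 bits
        e_bf16 = torch.clamp(e_bf16 + scale.to(torch.int32), max=(1 << 8) - 1)
    # sign | 8-bit exponent | 7-bit mantissa -> bfloat16 bit pattern
    val_bf16 = (((e_bf16 | (s << 8)) << 7) | (m_f4 << 6)) & 0xFFFF
    return val_bf16.to(torch.uint16).view(torch.bfloat16)
```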
Code Review
This pull request fixes a bug in the FP4 to FP16 conversion for Hopper and introduces a new BF16xMXFP4 GEMM kernel example. The bug fix appears correct. The new example is a valuable addition, but I've identified a critical bug in shared memory allocation, along with several opportunities for improvement in code clarity, consistency, and performance. My review includes suggestions to address these points.
    Scale_shared = T.alloc_shared((block_N, block_M // scale_size), storage_dtype)
    Scale_local = T.alloc_fragment((block_N, block_M // scale_size), storage_dtype)
There is a critical bug in the allocation of Scale_shared and Scale_local. Their shapes are defined using block_M, but they should be using block_K. The Scale tensor's dimensions are tied to N and K, and the copy operation T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared) expects a destination with a shape based on block_K. Using block_M will likely cause out-of-bounds memory access and incorrect results.
Suggested change:

    Scale_shared = T.alloc_shared((block_N, block_K // scale_size), storage_dtype)
    Scale_local = T.alloc_fragment((block_N, block_K // scale_size), storage_dtype)
    # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
    e_bf16 = e_f4 + tir.const(126, "uint16")
    # Scale is the exponential part, within the representation of uint8
    # To handle the overflow, we use the max function to limit the exponential part to 8 bits
The comment on this line is misleading. It states that a max function is used to handle overflow, but the implementation uses T.min to clamp the value. To avoid confusion, the comment should be updated to accurately reflect the code.
Suggested change:

    # To handle the overflow, we use the min function to clamp the exponential part to 8 bits
    def print_bit(name, val):
        val_cpu = val.cpu().item()
        binary_repr = f'{val_cpu:032b}'
        print(name, binary_repr)
    e_f16 = e_f4 + 126
    if scale is not None:
        e_f16 = min(e_f16 + scale, (1 << 8) - 1)
    m_f4 = f4 & 1
    m_f16 = m_f4
    val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
    lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
    return lower_16_bits.view(torch.bfloat16)
The variables e_f16 and val_f16 are used in the process of converting to bfloat16. This is confusing as the f16 suffix suggests float16. For better clarity and consistency with _tir_u8_to_f4_to_bf16, please rename them to e_bf16 and val_bf16 respectively.
Suggested change:

    e_bf16 = e_f4 + 126
    if scale is not None:
        e_bf16 = min(e_bf16 + scale, (1 << 8) - 1)
    m_f4 = f4 & 1
    m_f16 = m_f4
    val_bf16 = (((e_bf16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
    lower_16_bits = (val_bf16 & 0xFFFF).to(torch.uint16)
    return lower_16_bits.view(torch.bfloat16)
    T.annotate_layout({
        B_shared: tilelang.layout.make_swizzled_layout(B_shared),
        Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
    })
The Scale_shared buffer is missing a layout annotation here, while it is present in the main function for the split=1 case. For consistency and to ensure optimal performance, you should add a swizzled layout for Scale_shared in this main_split function as well.
Suggested change:

    T.annotate_layout({
        B_shared: tilelang.layout.make_swizzled_layout(B_shared),
        Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
        Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
    })
* [MXFP4] Dequantize FP4 kernel example, MX scale todo
* [BugFix] Fix the bug of fp4&fp16 exponential bias
* [MXFP4] Add group scale factor for BF16xMXFP4 gemm
* [Lint]
* [Test] Add test script for BF16xMXFP4 gemm
* [Lint]
* [BugFix] Fix the shape of scale tensor
* Update example_dequant_gemm_fp4_hopper.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>