
Conversation

Contributor

@tzj-fxz tzj-fxz commented Aug 8, 2025

  • Fix a bug in the FP4 GEMM on Hopper: incorrect exponential bias in the conversion.
  • Implement a BF16xMXFP4 GEMM kernel on Hopper for use by the vLLM community.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 August 8, 2025 03:15
@tzj-fxz tzj-fxz self-assigned this Aug 8, 2025

github-actions bot commented Aug 8, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @tzj-fxz, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've addressed a critical bug in the FP4 GEMM conversion on the Hopper architecture, specifically correcting an exponential bias issue. Concurrently, I've implemented a new BF16xMXFP4 GEMM kernel tailored for Hopper, which will be beneficial for the vLLM community. These changes enhance the accuracy and expand the capabilities of low-bit precision operations within the system.

Highlights

  • FP4 GEMM Conversion Bug Fix: The _tir_u8_to_f4_to_f16 and _convert functions in example_dequant_gemm_fp4_hopper.py were updated to correct the exponential bias calculation for FP4 to FP16 conversion. Specifically, the e_f4 calculation was changed from f4 & 7 to (f4 & 6) >> 1, and e_f16 from e_f4 | 8 to e_f4 + 14. Additionally, mantissa bits (m_f4 and m_f16) are now correctly incorporated into the val_f16 calculation, ensuring accurate floating-point representation. (A standalone sketch of this conversion appears after this list.)
  • BF16xMXFP4 GEMM Kernel Implementation: A new file, example_dequant_gemm_mxfp4_hopper.py, has been introduced. This file implements a BF16xMXFP4 GEMM kernel, including a _tir_u8_to_f4_to_bf16 function for converting 4-bit unsigned integers to bfloat16 with proper exponential bias and scaling. It also provides convert and convert_scale functions, and a matmul kernel for performing General Matrix Multiply (GEMM) operations with BF16xMXFP4 precision, supporting optional scaling.
  • New Test Case for BF16xMXFP4 GEMM: The test_example_dequantize_gemm.py file was modified to include a new test function, test_example_dequant_gemm_mxfp4_hopper, which validates the correctness of the newly implemented BF16xMXFP4 GEMM kernel. This ensures the new low-bit precision operations function as expected.
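
For reference, the following is a minimal, standalone Python sketch of the corrected FP4 (E2M1) to FP16 conversion described in the first highlight. It is not the kernel code: the helper name u4_to_f16_bits is hypothetical, and subnormal/zero FP4 values are ignored. The exponent is re-biased by 15 - 1 = 14 (FP16 bias minus FP4 bias), and the single FP4 mantissa bit lands in the top bit of the FP16 mantissa field.

import struct

def u4_to_f16_bits(f4: int) -> int:
    s = (f4 >> 3) & 1      # sign bit
    e_f4 = (f4 & 6) >> 1   # two exponent bits (the fix: previously f4 & 7)
    m_f4 = f4 & 1          # one mantissa bit
    e_f16 = e_f4 + 14      # re-bias: FP16 bias 15 minus FP4 bias 1 (the fix: previously e_f4 | 8)
    # FP16 layout: bit 15 sign, bits 14-10 exponent, bits 9-0 mantissa.
    return ((s << 15) | (e_f16 << 10) | (m_f4 << 9)) & 0xFFFF

# E2M1 nibble 0b0011 (s=0, e=1, m=1) encodes 1.5; 0b1100 (s=1, e=2, m=0) encodes -2.0.
assert struct.unpack("<e", struct.pack("<H", u4_to_f16_bits(0b0011)))[0] == 1.5
assert struct.unpack("<e", struct.pack("<H", u4_to_f16_bits(0b1100)))[0] == -2.0
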
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes a bug in the FP4 to FP16 conversion for Hopper and introduces a new BF16xMXFP4 GEMM kernel example. The bug fix appears correct. The new example is a valuable addition, but I've identified a critical bug in shared memory allocation, along with several opportunities for improvement in code clarity, consistency, and performance. My review includes suggestions to address these points.

Comment on lines 286 to 287
Scale_shared = T.alloc_shared((block_N, block_M // scale_size), storage_dtype)
Scale_local = T.alloc_fragment((block_N, block_M // scale_size), storage_dtype)

critical

There is a critical bug in the allocation of Scale_shared and Scale_local. Their shapes are defined using block_M, but they should be using block_K. The Scale tensor's dimensions are tied to N and K, and the copy operation T.copy(Scale[bx * block_N, k * block_K // scale_size], Scale_shared) expects a destination with a shape based on block_K. Using block_M will likely cause out-of-bounds memory access and incorrect results.

Suggested change
Scale_shared = T.alloc_shared((block_N, block_M // scale_size), storage_dtype)
Scale_local = T.alloc_fragment((block_N, block_M // scale_size), storage_dtype)
Scale_shared = T.alloc_shared((block_N, block_K // scale_size), storage_dtype)
Scale_local = T.alloc_fragment((block_N, block_K // scale_size), storage_dtype)
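
To make the shape reasoning above concrete, here is a small, standalone Python illustration (the tile sizes below are hypothetical, not taken from the example): the global Scale tensor is laid out as (N, K // scale_size), so the tile copied for block indices (bx, k) has shape (block_N, block_K // scale_size), which is the shape Scale_shared and Scale_local must match.

N, K = 1024, 1024
block_N, block_K, scale_size = 128, 128, 32

scale_global_shape = (N, K // scale_size)            # (1024, 32)
scale_tile_shape = (block_N, block_K // scale_size)  # (128, 4): required shape of Scale_shared/Scale_local
bx, k = 2, 3                                         # arbitrary block indices
rows = range(bx * block_N, (bx + 1) * block_N)
cols = range(k * block_K // scale_size, (k + 1) * block_K // scale_size)
assert (len(rows), len(cols)) == scale_tile_shape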

# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits

medium

The comment on this line is misleading. It states that a max function is used to handle overflow, but the implementation uses T.min to clamp the value. To avoid confusion, the comment should be updated to accurately reflect the code.

Suggested change
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
# To handle the overflow, we use the min function to clamp the exponential part to 8 bits

Comment on lines +37 to +40
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)

medium

The helper function print_bit is defined here but is not used anywhere in the torch_convert function. This dead code should be removed to improve maintainability.

Comment on lines +49 to +56
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)

medium

The variables e_f16 and val_f16 are used in the process of converting to bfloat16. This is confusing as the f16 suffix suggests float16. For better clarity and consistency with _tir_u8_to_f4_to_bf16, please rename them to e_bf16 and val_bf16 respectively.

Suggested change
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
e_bf16 = e_f4 + 126
if scale is not None:
e_bf16 = min(e_bf16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_bf16 = (((e_bf16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_bf16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
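
As a standalone sanity check of the bfloat16 bit packing used in the snippet above (plain Python, no project imports; normal FP4 values only, with the optional scale omitted): the expression (((e | (s << 8)) << 7) | (m << 6)) places the sign at bit 15, the 8 exponent bits at bits 14-7, and the single FP4 mantissa bit at bit 6, the top bit of the 7-bit bfloat16 mantissa.

import struct

def fp4_to_bf16(s: int, e_f4: int, m_f4: int) -> float:
    e_bf16 = e_f4 + 126  # re-bias: bfloat16 bias 127 minus FP4 bias 1 = 126
    val_bf16 = (((e_bf16 | (s << 8)) << 7) | (m_f4 << 6)) & 0xFFFF
    # A bfloat16 bit pattern is the high half of the corresponding float32.
    return struct.unpack(">f", struct.pack(">H", val_bf16) + b"\x00\x00")[0]

# E2M1: (s=0, e=1, m=1) encodes 1.5; (s=1, e=2, m=0) encodes -2.0.
assert fp4_to_bf16(0, 1, 1) == 1.5
assert fp4_to_bf16(1, 2, 0) == -2.0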

Comment on lines 238 to 241
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})

medium

The Scale_shared buffer is missing a layout annotation here, while it is present in the main function for the split=1 case. For consistency and to ensure optimal performance, you should add a swizzled layout for Scale_shared in this main_split function as well.

Suggested change
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
Scale_shared: tilelang.layout.make_swizzled_layout(Scale_shared),
})

@LeiWang1999 LeiWang1999 merged commit 569b012 into tile-ai:main Aug 10, 2025
3 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [MXFP4] Dequantize FP4 kernel example, MX scale todo

* [BugFix] Fix the bug of fp4&fp16 exponential bias

* [MXFP4] Add group scale factor for BF16xMXFP4 gemm

* [Lint]

* [Test] Add test script for BF16xMXFP4 gemm

* [Lint]

* [BugFix] Fix the shape of scale tensor

* Update example_dequant_gemm_fp4_hopper.py

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>