Skip to content

Conversation

@roikoren755
Copy link
Contributor

@roikoren755 roikoren755 commented Sep 30, 2025

Purpose

This PR fixes an issue with padded FP4 quantization, where previously initialized values in the allocated tensors aren't overwritten properly when quantizing to FP a tensor which requires padding.

This issue prevents running NVIDIA's new Nemotron-Nano-9B-v2 with TP2, as some of the tensor sizes in such cases do not divide evenly by 128.

Note that this specific model doesn't work with TP>2, and this PR doesn't solve issues that come up in those scenarios. For TP4 the issue comes from inside the FP4 GEMM kernel, where it expects some matrix size to be divisible by 32, and for TP8 the issue happens when creating the model weights, where some layers' in features are not divisible by 16. The same issues happen in TensorRT-LLM as well.

Test Plan

No need for new tests, as tests are already in place to verify FP4 quantization, both with and without padding.

Test Result

All existing FP4 quantization tests passed locally.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a critical bug in padded FP4 quantization by ensuring that allocated tensors for quantized output and scales are zero-initialized. The previous use of torch.empty could lead to uninitialized garbage values in padded regions, causing incorrect and non-deterministic behavior. By switching to torch.zeros, the padded areas are correctly filled with zeros. Additionally, assertions that were incompatible with padded tensor shapes in flashinfer_scaled_fp4_mm have been correctly removed. These changes are essential for enabling models that require padding for FP4 quantization, such as Nemotron-Nano-9B-v2 with tensor parallelism. The fixes are correct and well-targeted.

@roikoren755 roikoren755 force-pushed the fix/enable_padded_fp4 branch from f733f88 to 3298b91 Compare October 6, 2025 07:12

# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
output = torch.zeros((m, n // 2), device=device, dtype=torch.uint8)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a tensor of required out shape is allocated, why do we also need to initialize it with zeros? The divisibility check is only on the reducing dimension, could you clarify this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did it for consistency's sake, but you are correct, this allocation can stay as torch.empty. I'll revert this line.

@pavanimajety
Copy link
Collaborator

Thanks for the PR! Could you also post lm_eval results for any other FP4 model to ensure that previous paths don't break?

@mgoin mgoin self-requested a review October 7, 2025 19:40
…4_mm`

Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: Roi Koren <roik@nvidia.com>
@roikoren755 roikoren755 force-pushed the fix/enable_padded_fp4 branch from 3298b91 to 227f836 Compare October 8, 2025 08:19
@roikoren755
Copy link
Contributor Author

Thanks for the PR! Could you also post lm_eval results for any other FP4 model to ensure that previous paths don't break?

Ran GSM8K on NVFP4/Qwen3-30B-A3B-Instruct-2507-FP4 using the following command:

lm-eval --model vllm --model_args pretrained=NVFP4/Qwen3-30B-A3B-Instruct-2507-FP4,add_bos_token=true --tasks gsm8k --num_fewshot 5 --batch_size auto

And got the following results:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7453|±  |0.0120|
|     |       |strict-match    |     5|exact_match|↑  |0.7498|±  |0.0119|

Before my changes, the results are:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7445|±  | 0.012|
|     |       |strict-match    |     5|exact_match|↑  |0.7475|±  | 0.012|

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed bug Something isn't working quantization labels Oct 9, 2025
@mgoin
Copy link
Member

mgoin commented Oct 9, 2025

Note that this specific model doesn't work with TP>2, and this PR doesn't solve issues that come up in those scenarios. For TP4 the issue comes from inside the FP4 GEMM kernel, where it expects some matrix size to be divisible by 32, and for TP8 the issue happens when creating the model weights, where some layers' in features are not divisible by 16. The same issues happen in TensorRT-LLM as well.

This comment makes it seem like we should still be checking or asserting shapes at the kernel dispatch level to guide users, but we can leave that for future work if this fix is needed first.

@vllm-bot vllm-bot merged commit 4069db3 into vllm-project:main Oct 9, 2025
48 of 50 checks passed
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
@roikoren755 roikoren755 deleted the fix/enable_padded_fp4 branch October 20, 2025 07:30
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants