Skip to content

Conversation

vllmellm
Copy link
Contributor

@vllmellm vllmellm commented Jun 10, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Purpose

This PR introduces a new kernel from the aiter package, gemm_a8w8_bpreshuffle_CK, which supports FP8 quantization of per-token activations and per-channel weights for linear layers.

This PR upgrades aiter package to a newer commit 636a9f0.

Important:

per_tensor_activations = (x_scale.numel() == 1)

-        per_tensor_weights = (weight_scale.numel() == 1)
-        per_tensor_activations = (x_scale.numel() == 1)
+       per_tensor_weights = (weight_scale.numel()
+                              == 1) and weight_scale.dim() < 2
+        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

We have added to check the dimension of the tensors due to the following reason in Per-Token Quantization Scheme:

weight_scale.shape: torch.Size([7168, 1]), x_scale.shape: torch.Size([1, 1]), channelwise: True

the x_scale.shape is the edge case that the condition fails to catch. This is when there is 1 token and it is a per-token scaled weight. x_scale has the shape (num_tokens, num_of_scale_value): torch.Size([1, 1])
Model that has been used to test this is EmbeddedLLM/Qwen2.5-7B-Instruct-FP8-Dynamic link

The following is the shape of the weight_scale and x_scale for Per-Tensor quantized model: RedHatAI/Meta-Llama-3-8B-Instruct-FP8-KV

x_scale.dim(): 0, weight_scale.dim(): 0
weight_scale.shape: torch.Size([]), x_scale.shape: torch.Size([]), channelwise: False

We used the condition x_scale.dim() < 2 to cover the case to handle the case where the per-tensor scale can take the form of zero-dimension tensor or single dimension tensor of size 1.

Test Plan

  1. test lm_eval on RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic by enabling AITER LINEAR only.

VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=0 VLLM_ROCM_USE_AITER_RMSNORM=0 SAFETENSORS_FAST_GPU=1 lm_eval --model vllm --model_args pretrained=RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic,tensor_parallel_size=4,max_model_len=4086 --tasks gsm8k --num_fewshot 5 --batch_size auto

  1. test lm_eval on RedHatAI/Qwen3-235B-A22B-FP8-dynamic by enabling AITER LINEAR only.

VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MOE=0 VLLM_ROCM_USE_AITER_RMSNORM=0 SAFETENSORS_FAST_GPU=1 lm_eval --model vllm --model_args pretrained=RedHatAI/Qwen3-235B-A22B-FP8-dynamic,tensor_parallel_size=4,max_model_len=4086 --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

  1. RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic
Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.9174 ± 0.0076
strict-match 5 exact_match 0.9030 ± 0.0082
  1. RedHatAI/Qwen3-235B-A22B-FP8-dynamic
Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8658 ± 0.0094
strict-match 5 exact_match 0.8514 ± 0.0098

Online Serving Benchmarking Results

commands:
vllm serve $model_name --distributed-executor-backend mp --swap-space 16 --disable-log-requests

python benchmarks/benchmark_serving.py --backend vllm --model "$model_name" --dataset-name random --num-prompts 50 --goodput "ttft:3000" "tpot:50" --random-input-len 1000 --random-output-len 1000 --random-range-ratio 0.9

Metric LLama-4-Scout-17B-16E-Instruct-FP8-dynamic No AITER LLama-4-Scout-17B-16E-Instruct-FP8-dynamic Yes AITER RedHatAI/Qwen3-235B-A22B-FP8-dynamic No AITER RedHatAI/Qwen3-235B-A22B-FP8-dynamic Yes AITER
Successful requests 50 50 50 50
Benchmark duration (s) 41.88 42.56 82.05 80.75
Total input tokens 50392 50392 50392 50392
Total generated tokens 15427 15825 32287 29502
Request throughput (req/s) 1.19 1.17 0.61 0.62
Request goodput (req/s) 1.10 1.03 0.05 0.04
Output token throughput (tok/s) 368.39 371.87 393.49 365.37
Total token throughput (tok/s) 1571.73 1556.03 1007.63 989.44
Mean TTFT (ms) 2105.58 2086.39 3644.37 3679.19
Median TTFT (ms) 2228.77 2086.58 3904.48 4015.55
P99 TTFT (ms) 3005.01 3028.35 5537.26 5496.08
Mean TPOT (ms) 40.71 38.53 93.71 76.33
Median TPOT (ms) 36.84 38.73 50.12 52.34
P99 TPOT (ms) 143.19 58.58 736.52 529.77
Mean ITL (ms) 31.80 32.68 49.54 50.81
Median ITL (ms) 30.04 31.74 46.23 47.32
P99 ITL (ms) 369.34 209.55 51.30 53.30

vllmellm added 3 commits June 9, 2025 12:58
Co-authored-by: tjtanaa tunjian.tan@embeddedllm.com
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @vllmellm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for using ROCm AITER kernels for W8A8 linear layers, specifically for per-token activation and per-channel weight quantization on MI3XX GPUs. It includes the necessary logic to enable the AITER kernel based on configuration, handles the required weight format, and adds tests to ensure correct functionality and compatibility with PyTorch's compilation features. The author has provided test results showing performance metrics with AITER enabled on specific FP8 dynamic models.

Highlights

  • ROCm AITER Integration: Integrated ROCm AITER kernels specifically for W8A8 (per-token activation, per-channel weight) GEMM operations, targeting MI3XX hardware.
  • Conditional Kernel Dispatch: Implemented logic to conditionally dispatch the AITER GEMM kernel based on platform (ROCm), hardware (MI3XX), and environment variables (VLLM_ROCM_USE_AITER, VLLM_ROCM_USE_AITER_LINEAR).
  • Weight Layout Handling: Modified the weight processing for W8A8 compressed tensors to handle AITER's requirement for a shuffled (N, K) weight layout, applying the shuffle during weight loading if AITER is enabled.
  • New AITER Tests: Added a new test file to verify the registration of the custom AITER GEMM op and its compatibility with torch.compile.
  • AITER Dependency Update: Updated the AITER dependency version specified in the ROCm base Dockerfile.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configureGemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the ci/build label Jun 10, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request integrates AITER GEMM w8a8 operations for ROCm, including tests, utility functions, and dispatch logic. Key areas for attention include numerical tolerances in tests, clarifying parameter usage, and adding comments to explain key steps.

@vllmellm vllmellm changed the title Aiter gemm w8a8 ptpc [ROCm][FEAT] Integrate AITER gemm w8a8 ptpc Jun 10, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Jun 11, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
Copy link

mergify bot commented Jun 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 22, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot removed the needs-rebase label Jul 20, 2025
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Oct 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build qwen Related to Qwen models rocm Related to AMD ROCm stale Over 90 days of inactivity

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant