Skip to content

Conversation

tjtanaa
Copy link
Contributor

@tjtanaa tjtanaa commented Mar 17, 2025

Description

This PR is to introduce AITER Linear kernel so that any up-coming optimization in AITER kernel could be directly use and evaluated within vLLM framework.

(Updates) This PR has been updated to also support V1.

Given that the AITER ops is being used by multiple files, and requires to be registered as custom op through direct_register_custom_op for it to be used with V1, we proposed to keep all AITER related flags (helper functions) and ops within vllm/_aiter_ops.py. We would like to propose that moving forward, AITER ops will be first defined in this file and registered as custom ops using direct_register_custom_op. AITER related flags (helper functions) will be defined here as well.

We have added tests/v1/rocm/test_aiter_ops.py . It tests if the aiter ops are

  1. correctly registered as custom ops
  2. correctly defined the relationship between implementation and fake function
  3. can be used with torch.compile
    This file will be skipped if AITER is not installed and the platform is not ROCm.

NOTE:
This unit tests is by no means to check the correctness of the AITER ops. It only checks if the AITER ops are correctly registered and if torch.compile can be used with the AITER ops. The correctness of the AITER ops is tested in the https://github.com/ROCm/aiter

To meet those two criteria, the following checks are done.

Criteria 1

To check for fake tensor implementation is correct or not, it can be done using torch.library.check as follows:

    # Verify the op's fake implementation with FP8 inputs
    # Disable test_schema as fp8 datatype is not supported by
    # torch.library.opcheck
    # Related error:
    #      OpCheckError: opcheck(op, ...): test_schema failed with
    #      "mul_cuda" not implemented for 'Float8_e4m3fnuz'
    torch.library.opcheck(torch.ops.vllm.rocm_aiter_tuned_gemm,
                          (input_fp8, weight_fp8),
                          kwargs={
                              "out_dtype": torch.float16,
                              "scale_a": scale_a,
                              "scale_b": scale_b
                          },
                          test_utils=("test_faketensor"))

Criteria 2

Ensure the torch.compile mode and the eager mode generates the same results

    # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
    result_original = gemm_fp8_fn(input_fp8, weight_fp8, scale_a, scale_b)
    result_compiled = compiled_fp8_fn(input_fp8, weight_fp8, scale_a, scale_b)

What is AITER Linear?

AITER Linear is a kernel from the AI Tensor Engine for ROCm (AITER) that has been integrated into vLLM for the unquantized linear method and for per-tensor-weight and per-tensor-activation quantization FP8 Scaled GEMM.

How to Enable AITER Linear

  • It is enabled by default when the environment variable VLLM_ROCM_USE_AITER=1 is set
  • It can be specifically enabled or disabled using its dedicated environment variable VLLM_ROCM_USE_AITER_LINEAR
Out-dated-results

Performance Comparison with No-AITER

All tests are V0 Engine

Llama-3.1-8B-Instruct (with FP8 per-tensor dynamic quantization)

  • With Linear only: -2% to 2% performance improvement

Llama-3.1-8B-Instruct-BF16

  • With Linear only: -5% to 4% performance improvement

Llama-3.1-70B-Instruct (with FP8 per-tensor dynamic quantization)

  • With Linear only: -1% to 0.5% performance improvement

Llama-3.1-70B-Instruct-BF16

  • With Linear only: -0.3% to 0.2% performance improvement

Throughput Performance Comparison with No-AITER

Before PR: [Performance][ROCm] Add skinny gemms for unquantized linear on ROCm #15830

Settings:

  • Model: Llama-3.1-8B-Instruct
  • Cases (Input Token Length: Output Token Length):
    • 128: 128
    • 128: 2048
    • 2048: 128
    • 2048: 2048

Llama-3.1-8B-Instruct (with FP8 per-tensor dynamic quantization)

  • With Linear only (Performance gain):
    • 128: 128: 1.93x
    • 128: 2048: 1x
    • 2048: 128: 1x
    • 2048: 2048: 1x

Llama-3.1-8B-Instruct-BF16

  • With Linear only (Performance gain):
    • 128: 128: 1.75x
    • 128: 2048: 1x
    • 2048: 128: 1x
    • 2048: 2048: 1x

V1 Accuracy Test

Settings:

  • Model: Llama-3.1-8B-Instruct

Unquantized No AITER

vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,tensor_parallel_size=1,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.7824 ± 0.0114
strict-match 5 exact_match 0.7566 ± 0.0118

AITER Linear unquantized

vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,tensor_parallel_size=1,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.7824 ± 0.0114
strict-match 5 exact_match 0.7566 ± 0.0118

AITER Linear per-tensor FP8

vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,tensor_parallel_size=1,max_model_len=10000,quantization=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.6414 ± 0.0132
strict-match 5 exact_match 0.6020 ± 0.0135

vllmellm and others added 3 commits March 15, 2025 07:27
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 17, 2025
tjtanaa added 5 commits March 17, 2025 03:10
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please split up the linear between unquantized and quantized? That way we can land unquantized now and the quantized soon after, once I'm done with the FP8 scaledmm refactor.

return tgemm.mm(x, weight, bias)


def dipsatch_unquantized_linear_func() -> Callable[..., torch.Tensor]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be good to specify the exact signature here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ProExpertProg We have fix the typo and specified the exact signature


def dispatch_unquantized_linear_func(
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    from vllm._aiter_ops import is_rocm_aiter_linear_enabled
    if is_rocm_aiter_linear_enabled():
        return aiter_ops.rocm_aiter_tuned_gemm
    return F.linear

Copy link

mergify bot commented Mar 31, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Contributor

@SageMoore SageMoore left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch generally looks fine to me. Lets rebase it and clean it up a bit. Same as the other AITER PRs, can you post lm_eval results from models that this kernel should support?

@tjtanaa
Copy link
Contributor Author

tjtanaa commented Apr 16, 2025

This patch generally looks fine to me. Lets rebase it and clean it up a bit. Same as the other AITER PRs, can you post lm_eval results from models that this kernel should support?

This kernel is for generic use case. So we will pick just one model and compute its lmeval to show the correctness of the implementation which is llama3.1.

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@mergify mergify bot removed the needs-rebase label Apr 16, 2025
tjtanaa added 5 commits April 16, 2025 19:04
…move redundant env variables

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
tjtanaa added 3 commits April 19, 2025 16:14
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
…y and the torch compile mode works

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@mergify mergify bot added the v1 label Apr 19, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa
Copy link
Contributor Author

tjtanaa commented Apr 20, 2025

@SageMoore @ProExpertProg We have updated the PR description with the latest information.
In this PR it will also enable the AITER kernels to be used with V1 Engine.

Should we keep the is_rocm_aiter_xxx_enabled flag in the platforms/rocm.py in the same place where use_rocm_custom_paged_attention and use_custom_allreduce are stored?

tjtanaa added 3 commits April 21, 2025 08:22
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>


def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
def cutlass_w8a8_scaled_mm(qinput: torch.Tensor, weight: torch.Tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor
->
cutlass_w8a8_scaled_mm(qinput: torch.Tensor

why did we remove the starting *? is it on purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this is on purpose. The input argument name of the AITER tuned gemm is not call qinput thus will break the dispatcher function logic. So removing the * and avoid specifying the first argument as keyword argument.

@mergify mergify bot added the rocm Related to AMD ROCm label Jun 13, 2025
Copy link

mergify bot commented Jun 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Sep 12, 2025
Copy link

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions bot closed this Oct 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build needs-rebase rocm Related to AMD ROCm stale Over 90 days of inactivity v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants