
[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing #12501

Merged: 9 commits merged into vllm-project:main on Feb 7, 2025

Conversation

@tjtanaa (Contributor) commented Jan 28, 2025

Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing

Co-authored-by: @kliuae

Note: This feature requires ROCm 6.3 or later and an MI300-series or newer GPU.

Description

This PR involves the following enhancements:

  1. Support for Per-Token-Activation Per-Channel-Weight (PTPC-FP8) FP8 quantization inferencing.
     The model is quantized on the fly from BFloat16 to FP8; model weights stored in Float16 are first cast to BFloat16. (A sketch of the scale computation follows this list.)

  2. It uses PyTorch's rowwise scaled GEMM support in torch._scaled_mm, introduced in [ROCm] hipblaslt rowwise f8 gemm pytorch/pytorch#144432, which speeds up the previous naive implementation by at least 2x. For more details, see the Performance section.

  • To support this feature, the PyTorch repo commit in Dockerfile.rocm_base has been updated to 3a585126.
  • Dockerfile.rocm is left untouched, as its base image references the AMD Docker Hub registry; at this point in time that base image already has PyTorch repo commit 3a585126 installed.

  3. Small enhancement: the documentation has been updated to ROCm 6.3, and the commits in the installation steps have been updated to match those in Dockerfile.rocm_base.
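
For readers unfamiliar with the scheme, here is a minimal sketch (illustrative only, not the code added by this PR) of how per-token activation scales and per-channel weight scales can be derived for FP8; the dtype choice and helper names are assumptions.

```python
# Minimal PTPC-FP8 scale-computation sketch; illustrative, not the vLLM code.
# Assumes a PyTorch build with FP8 dtypes available.
import torch

FP8_DTYPE = torch.float8_e4m3fnuz          # FP8 flavor used on MI300 (ROCm)
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize_activation_per_token(x: torch.Tensor):
    """x: [num_tokens, hidden] (bf16) -> (fp8 tensor, per-token scales [num_tokens, 1])."""
    scale = x.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)
    x_fp8 = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8, scale

def quantize_weight_per_channel(w: torch.Tensor):
    """w: [out_features, in_features] (bf16) -> (fp8 tensor, per-channel scales [out_features, 1])."""
    scale = w.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)
    w_fp8 = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return w_fp8, scale
```

At inference time the FP8 GEMM output is rescaled by these per-token and per-channel scales, which is exactly what the naive path shown under Performance does after calling torch._scaled_mm.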

Performance

Perplexity Test

Model: Llama-3.1-8B-Instruct
Dataset: Wikitexts
GPU: MI300X

| Model | Quantization | KVCacheDtype | Tasks | Metric | Score |
|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct | auto (bf16) | auto (bf16) | wikitext | word_perplexity | 9.4281 |
| Llama-3.1-8B-Instruct | fp8 | fp8_e4m3 | wikitext | word_perplexity | 9.5124 |
| Llama-3.1-8B-Instruct | ptpc_fp8 | fp8_e4m3 | wikitext | word_perplexity | 9.5093 |
| Llama-3.1-8B-Instruct | ptpc_fp8 (naive) | fp8_e4m3 | wikitext | word_perplexity | 9.5095 |
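
For context, a Wikitext perplexity number like those above could be reproduced with lm-eval-harness's vLLM backend roughly as follows; the model path, result keys, and exact argument spelling are assumptions rather than commands taken from this PR.

```python
# Hypothetical reproduction sketch using lm-eval-harness's vLLM backend;
# arguments and result keys are assumptions, not taken from the PR.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=("pretrained=meta-llama/Llama-3.1-8B-Instruct,"
                "quantization=ptpc_fp8,kv_cache_dtype=fp8_e4m3"),
    tasks=["wikitext"],
)
print(results["results"]["wikitext"])  # includes word_perplexity among other metrics
```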

Speed Test (old naive implementation vs. torch._scaled_mm rowwise scaled GEMM)

Model: Llama-3.1-70B-Instruct
Dataset: ShareGPT
GPU: 1x MI300X

| Quantization | KVCacheDtype | Req/s | Total tokens/s | Output tokens/s |
|---|---|---|---|---|
| ptpc_fp8 (naive) | fp8_e4m3 | 2.43 | 1003.46 | 481.28 |
| ptpc_fp8 (torch._scaled_mm rowwise scaled GEMM) | fp8_e4m3 | 6.36 | 2631.04 | 1261.91 |
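
To make the comparison concrete, enabling the PTPC-FP8 path benchmarked above looks roughly like this from the Python API. The quantization and KV-cache dtype strings are taken from the tables; the model path and sampling settings are illustrative.

```python
# Usage sketch: on-the-fly PTPC-FP8 on ROCm (MI300 or newer, ROCm 6.3+).
# Model path and sampling parameters are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-70B-Instruct",
          quantization="ptpc_fp8",       # per-token activation, per-channel weight FP8
          kv_cache_dtype="fp8_e4m3")     # FP8 KV cache, as in the benchmarks above
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```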

PTPC_FP8 (naive)


  # Making sure the dummy tensor is on the same device as the weight
  global TORCH_DEVICE_IDENTITY
  if TORCH_DEVICE_IDENTITY.device != weight.device:
      TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

  # GEMM
  # This computes C = (X * W).
  # Output in fp32 to allow subsequent ops to happen in-place
  output = torch._scaled_mm(qinput,
                            weight,
                            scale_a=TORCH_DEVICE_IDENTITY,
                            scale_b=TORCH_DEVICE_IDENTITY,
                            out_dtype=torch.float32)
  # A fix for discrepancy in scaled_mm which returns tuple
  # for torch < 2.5 and a single value in torch >= 2.5
  if type(output) is tuple and len(output) == 2:
      output = output[0]
  # Unpad (undo num_token_padding)
  output = torch.narrow(output, 0, 0, input_2d.shape[0])
  x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

  # DQ
  # C = sw * sx * (X * W) + bias
  output = output * x_scale * weight_scale.t()
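
For comparison, here is a sketch of the fast path enabled by the rowwise scaled GEMM support (my reading of the PR description, not a verbatim excerpt): the per-token and per-channel scales are handed directly to torch._scaled_mm instead of being applied in a separate dequantization step.

```python
# Sketch of the rowwise scaled GEMM fast path (assumed shapes, not a verbatim
# excerpt from this PR):
#   qinput       : [M, K] fp8, per-token quantized activations
#   weight       : [K, N] fp8, per-channel quantized weights
#   x_scale      : [M, 1] fp32, per-token (row-wise) activation scales
#   weight_scale : [N, 1] fp32, per-output-channel weight scales
output = torch._scaled_mm(qinput,
                          weight,
                          scale_a=x_scale,           # row-wise scales for A
                          scale_b=weight_scale.t(),  # [1, N] column-wise scales for B
                          out_dtype=input_2d.dtype)  # e.g. torch.bfloat16
# No separate "output * x_scale * weight_scale.t()" step is needed; hipBLASLt
# applies the scales inside the GEMM, which is where the ~2.6x speedup comes from.
```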

Additional side fixes made while working on this PR

Fix FP8 unit tests for ROCm

@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

        def check_model(model):
            attn = model.model.layers[0].self_attn.attn

            assert isinstance(attn.quant_method, Fp8KVCacheMethod)

            if not current_platform.is_rocm():
                # NOTE: This code path requires validation on Non-CUDA platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                assert 0.0 < attn._k_scale < 1.0
                assert 0.0 < attn._v_scale < 1.0
            else:
                # NOTE: This code path is for the ROCm platform
                # NOTE: it is valid for scales to be 1.0 (default value), but
                # we know these checkpoints have scales < 1.0
                # However, on the ROCm platform the _k_scale and _v_scale are
                # scaled by a factor of 2, as described in
                # vllm/model_executor/layers/quantization/kv_cache.py
                assert 0.0 < attn._k_scale < (1.0 * 2.0)
                assert 0.0 < attn._v_scale < (1.0 * 2.0)

        llm.apply_model(check_model)

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
        outputs = llm.generate_greedy(prompts=["Hello my name is"],
                                      max_tokens=10)
        print(outputs[0][1])
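
As background for the factor of 2 in the ROCm branch (my understanding of the rationale, not text from the PR): MI300 uses the fnuz FP8 variant, whose representable range is roughly half that of the OCP e4m3fn format, so checkpoint KV-cache scales are doubled when loaded on ROCm, as noted in vllm/model_executor/layers/quantization/kv_cache.py.

```python
# Illustrative sketch of why the ROCm assertions use an upper bound of 1.0 * 2.0.
# (Rationale is my understanding, not quoted from the PR.)
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0  (OCP e4m3, NVIDIA)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0  (fnuz e4m3, MI300)

# A checkpoint scale calibrated against e4m3fn ...
checkpoint_scale = 0.75
# ... is doubled when the KV cache is stored as e4m3fnuz on ROCm,
rocm_scale = checkpoint_scale * 2.0
# which is why the test allows values up to (1.0 * 2.0) on that platform.
assert 0.0 < rocm_scale < 1.0 * 2.0
```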

kliuae and others added 3 commits January 28, 2025 06:32

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mgoin (Member) left a comment


Thanks for your contribution. Do you think it would be possible to implement this inside of fp8.py? It seems we could just change the default --quantization fp8 on an unquantized model to use per-token and per-channel. Given the cutlass and pytorch support we have now, I don't think there is a great reason to rely on per-tensor by default anymore

@tjtanaa (Contributor, Author) commented Jan 29, 2025

> Thanks for your contribution. Do you think it would be possible to implement this inside of fp8.py? It seems we could just change the default --quantization fp8 on an unquantized model to use per-token and per-channel. Given the cutlass and pytorch support we have now, I don't think there is a great reason to rely on per-tensor by default anymore

@mgoin We think it should be possible to implement PTPC-FP8 inside fp8.py and force --quantization fp8 on an unquantized model to use per-token and per-channel quantization on ROCm. Changing the default behavior of --quantization fp8 on NVIDIA GPUs would require an additional PR.

@hongxiayang @mgoin Maybe we could get input from AMD on whether there is any preference or demand for keeping per-tensor quantization for backward compatibility.

Since we are on this topic: I recall that making vLLM production-ready is a goal, so I wonder whether we need to maintain some level of backward compatibility so that feature behavior changes as little as possible.
Moreover, when we added this new quantization feature and wanted to document its behavior, usage, and expectations, we couldn't find a page for it. Is there an RFC for documentation of quantization approaches?

@mgoin (Member) left a comment


After offline discussion, it would be best to keep this separate from the fp8 quantization and treat the NVIDIA case separately. I think this is good to land once these last requests are addressed.

@tjtanaa tjtanaa changed the title [ROCm] [Feature] [Doc] [Dockerfile] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing [ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing Feb 1, 2025
@tjtanaa (Contributor, Author) commented Feb 5, 2025

@mgoin I have addressed the comments. Is it ready for merging?

@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 7, 2025
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 7, 2025 14:44
@simon-mo simon-mo merged commit eaa92d4 into vllm-project:main Feb 7, 2025
70 of 72 checks passed
AoyuQC pushed a commit to AoyuQC/vllm that referenced this pull request Feb 8, 2025
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
SzymonOzog pushed a commit to SzymonOzog/vllm that referenced this pull request Feb 12, 2025
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
Labels: ci/build, documentation, ready, rocm
5 participants