
Conversation

Contributor

@vllmellm vllmellm commented Aug 21, 2025

Purpose

Integrate the AITER custom all-reduce into the CUDA communicator, which boosts model performance.
This PR was tested against commit 5ee37dc of the aiter package.

Benchmark Results

deepseek-ai/DeepSeek-V3 tp8

| Metric | With AITER CustomAllreduce | Without AITER CustomAllreduce |
|---|---|---|
| Request Throughput (req/s) | 1.82 | 1.73 |
| Output Token Thpt (tok/s) | 494.16 | 484.04 |
| Total Token Thpt (tok/s) | 2311.82 | 2212.63 |
| Mean TTFT (ms) | 50.73 | 381.81 |
| Median TTFT (ms) | 48.39 | 389.49 |
| P99 TTFT (ms) | 64.95 | 709.10 |
| Mean TPOT (ms) | 22.98 | 117.79 |
| Median TPOT (ms) | 22.99 | 101.26 |
| P99 TPOT (ms) | 25.47 | 434.76 |
| Mean ITL (ms) | 22.92 | 25.73 |
| Median ITL (ms) | 22.98 | 23.73 |
| P99 ITL (ms) | 25.43 | 151.28 |

meta-llama/Llama-4-Scout-17B-16E-Instruct tp8

| Metric | With AITER CustomAllreduce | Without AITER CustomAllreduce |
|---|---|---|
| Request Throughput (req/s) | 3.05 | 2.98 |
| Output Token Thpt (tok/s) | 913.81 | 859.67 |
| Total Token Thpt (tok/s) | 3935.91 | 3814.10 |
| Mean TTFT (ms) | 122.87 | 150.23 |
| Median TTFT (ms) | 108.24 | 148.77 |
| P99 TTFT (ms) | 190.81 | 227.81 |
| Mean TPOT (ms) | 18.20 | 20.51 |
| Median TPOT (ms) | 16.82 | 18.56 |
| P99 TPOT (ms) | 27.95 | 34.66 |
| Mean ITL (ms) | 15.93 | 16.87 |
| Median ITL (ms) | 12.88 | 13.30 |
| P99 ITL (ms) | 79.62 | 90.90 |

Qwen/Qwen3-235B-A22B-FP8 tp4

| Metric | With AITER CustomAllreduce | Without AITER CustomAllreduce |
|---|---|---|
| Request Throughput (req/s) | 1.69 | 1.61 |
| Output Token Thpt (tok/s) | 1079.88 | 1116.64 |
| Total Token Thpt (tok/s) | 2761.42 | 2723.95 |
| Mean TTFT (ms) | 543.23 | 590.96 |
| Median TTFT (ms) | 550.70 | 600.72 |
| P99 TTFT (ms) | 990.83 | 1012.52 |
| Mean TPOT (ms) | 48.90 | 36.36 |
| Median TPOT (ms) | 27.36 | 28.08 |
| P99 TPOT (ms) | 428.00 | 143.51 |
| Mean ITL (ms) | 27.16 | 28.30 |
| Median ITL (ms) | 23.75 | 26.43 |
| P99 ITL (ms) | 131.04 | 29.65 |

Benchmark Setting

```bash
python vllm/benchmarks/benchmark_serving.py --backend vllm --model "$model_name" --dataset-name random --num-prompts 50 --request-rate 10 --random-input-len 1000 --random-output-len 1000
```

Test Plan

Test the models affected by this change using lm_eval on the gsm8k dataset.

Environment Setting

Step 1: run vllm serve

```bash
VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE=1 SAFETENSORS_FAST_GPU=1 \
vllm serve $MODEL_NAME --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' --trust-remote-code --max-model-len 32768 -tp 8 --block-size 1 --swap-space 16 --distributed-executor-backend mp
```

Step 2: run lm_eval

```bash
lm_eval --model local-completions --tasks gsm8k --model_args model=$MODEL_NAME,base_url=http://localhost:8000/v1/completions --trust_remote_code --num_fewshot 5 --batch_size 100
```

Test Results

deepseek-ai/DeepSeek-V3 tp8

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9477 | ± 0.0061 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.9469 | ± 0.0062 |

zejunchen-zejun and others added 3 commits August 20, 2025 05:00
control it by the env flag VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
(default: True)

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Aug 21, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@vllmellm vllmellm marked this pull request as ready for review August 21, 2025 11:20
Collaborator

@ProExpertProg ProExpertProg left a comment


Can you refactor this to extend the current dispatching instead of using conditional imports?
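
For context, a minimal sketch of what "extending the dispatching" could look like, as opposed to branching on imports inside the using module; the registry and all names below are hypothetical, not the actual vLLM mechanism:

```python
# Hypothetical registry-based dispatch: backends register their allreduce
# class once, and call sites resolve through a single entry point instead of
# scattering conditional imports.
_ALLREDUCE_IMPLS: dict[str, type] = {}


def register_allreduce_impl(name: str):
    """Decorator that records an allreduce implementation under a key."""
    def wrap(cls: type) -> type:
        _ALLREDUCE_IMPLS[name] = cls
        return cls
    return wrap


def resolve_allreduce_impl(name: str) -> type:
    """Look up the implementation chosen for the current platform/config."""
    return _ALLREDUCE_IMPLS[name]


@register_allreduce_impl("cuda")
class DefaultCustomAllreduce:
    ...


@register_allreduce_impl("rocm-aiter")
class AiterCustomAllreduce:
    ...
```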

@vllmellm
Contributor Author

> Can you refactor this to extend the current dispatching instead of using conditional imports?

@ProExpertProg The requested modification has been applied.

…y and fix pre-commit error

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@ilmarkov
Contributor

@vllmellm Could you also benchmark against QuickReduce in vLLM? It is another alternative to custom allreduce for ROCm with good speedup numbers. It can be enabled by this env variable.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Contributor

@SageMoore SageMoore left a comment


Looks reasonable. Just nits.

"""Dispatch the custom allreduce implementation based on the platform."""
if is_rocm_aiter_custom_allreduce_enabled():
from aiter.dist.custom_all_reduce import CustomAllreduce
logger.info("Using aiter.dist.custom_all_reduce for ROCm platform")
Contributor


Nit: info_once
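
i.e., using the once-per-process logging variant the reviewer refers to (assuming the usual vLLM logger helper signature):

```python
logger.info_once("Using aiter.dist.custom_all_reduce for ROCm platform")
```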

```diff
 )

-self.ca_comm: Optional[CustomAllreduce] = None
+self.ca_comm: Optional[CustomAllreduce] = None  # type: ignore
```
Contributor


Nit: I think if you add a __call__ method to the CustomAllreduceProtocol, you can get rid of the # type: ignore.
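
For readers unfamiliar with the approach, a rough sketch of a structural Protocol for this attribute; the members shown are illustrative, not the actual vLLM definitions:

```python
from typing import Optional, Protocol

import torch


class CustomAllreduceProtocol(Protocol):
    """Structural type both vLLM's and aiter's CustomAllreduce would match."""

    disabled: bool

    def custom_all_reduce(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
        ...

    def close(self) -> None:
        ...


# The attribute can then be annotated with the protocol instead of a
# concrete class, avoiding the "# type: ignore":
ca_comm: Optional[CustomAllreduceProtocol] = None
```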

Contributor Author


@SageMoore Unfortunately, this solution didn't work out on its own. To get rid of # type: ignore, the type had to be changed to CustomAllreduceProtocol, and the additional methods and attributes required by the CustomAllreduce class had to be implemented in the protocol; with this, mypy now passes.

@ilmarkov
Contributor

@vllmellm @SageMoore I would suggest adding a new aiter_comm, not as a complete replacement of the current custom allreduce but as an alternative to it, enabled/disabled by an env variable or config.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm
Contributor Author

> @vllmellm @SageMoore I would suggest adding a new aiter_comm, not as a complete replacement of the current custom allreduce but as an alternative to it, enabled/disabled by an env variable or config.

@ilmarkov Thank you for the suggestion. Indeed, having a separate module such as aiter_comm is a cleaner solution. However, we are bound by the implementation in the aiter package, where CustomAllreduce is an exact duplicate of the one in the vLLM framework with some modifications for enhancements on the ROCm platform. Thus, the intention is to replace this module rather than extend the current one in the vLLM framework. It would be difficult to extend on top of vLLM's existing module unless we moved the enhancement logic from the aiter package into vLLM, which doesn't sound optimal in terms of maintenance and code readability.

Based on your comment, I have added back a separate environment flag for enabling/disabling this.

@ilmarkov
Contributor

@vllmellm Sorry if I shared the idea unclearly. I am suggesting adding aiter_comm independently of the existing ca_comm, not extending the CustomAllreduce class: just have a new comm next to ca_comm in CudaCommunicator, similar to the existing pynccl_comm and qr_comm.
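
A rough sketch of the suggested layout; the attribute names follow the discussion, but the class body is hypothetical, not the actual CudaCommunicator:

```python
from typing import Optional

import torch


class CudaCommunicator:
    """Illustrative only: a dedicated aiter comm next to the existing comms."""

    def __init__(self) -> None:
        self.pynccl_comm = None  # existing NCCL comm
        self.qr_comm = None      # existing QuickReduce comm
        self.ca_comm = None      # existing custom-allreduce comm
        self.aiter_comm = None   # new: aiter custom allreduce, env-gated

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Try the aiter comm first, then the stock custom allreduce; each
        # comm returns None when it cannot handle the input.
        for comm in (self.aiter_comm, self.ca_comm):
            if comm is not None and not comm.disabled:
                out = comm.custom_all_reduce(input_)
                if out is not None:
                    return out
        raise NotImplementedError("remaining fallback paths omitted")
```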

@vllmellm
Contributor Author

vllmellm commented Aug 29, 2025

> @vllmellm Sorry if I shared the idea unclearly. I am suggesting adding aiter_comm independently of the existing ca_comm, not extending the CustomAllreduce class: just have a new comm next to ca_comm in CudaCommunicator, similar to the existing pynccl_comm and qr_comm.

@ilmarkov I see what you mean; thank you for elaborating. I understand your point of view; however, the CustomAllreduce from aiter is literally a duplicate of the one in vLLM, and wherever ca_comm is referenced we would need to check whether to use aiter and then use aiter_comm instead, which would clutter the code with if/else statements.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@ilmarkov
Contributor

ilmarkov commented Sep 8, 2025

@vllmellm Isn't there a scenario where we could use aiter_comm for a certain range of input sizes and ca_comm for the rest, e.g., for better performance?
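
For illustration, hypothetical size-based routing along those lines; the threshold, constant, and attribute access are made up for the example:

```python
import torch

AITER_MAX_BYTES = 8 * 1024 * 1024  # assumed tuning knob, not a real constant


def pick_allreduce_comm(communicator, input_: torch.Tensor):
    """Route small tensors to aiter_comm and the rest to ca_comm."""
    nbytes = input_.numel() * input_.element_size()
    if communicator.aiter_comm is not None and nbytes <= AITER_MAX_BYTES:
        return communicator.aiter_comm
    return communicator.ca_comm
```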

@tjtanaa
Contributor

tjtanaa commented Sep 8, 2025

@ilmarkov Could you guide us on how to evaluate the speed of QuickReduce while also taking into account the quantization error it introduces? How should we set the quantization level (VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=[NONE|FP|INT8|INT6|INT4])?
Or should we only compare against VLLM_ROCM_QUICK_REDUCE_QUANTIZATION=FP?

