Conversation

@wenscarl (Contributor) commented Aug 29, 2025

Purpose

Fixes #22916

Test Plan

Accuracy:

VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="latency" \
/home/shuw/.local/bin/lm_eval --model vllm \
  --model_args pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,data_parallel_size=4,enable_expert_parallel=True,tensor_parallel_size=1,max_model_len=2048 \
  --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Perf:

VLLM_ALL2ALL_BACKEND="allgather_reducescatter" \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
VLLM_USE_STANDALONE_COMPILE=0 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="latency" \
  /home/shuw/.local/bin/vllm serve nvidia/DeepSeek-R1-FP4 \
    --quantization="modelopt_fp4" \
    --trust-remote-code \
    --max-model-len=2048 \
    --block-size=128 \
    --max-num-seqs=256 \
    --enable-expert-parallel \
    --gpu_memory_utilization=0.8 \
    --tensor-parallel-size 1 \
    --data-parallel-size 4
    
 python benchmarks/benchmark_serving.py \
  --model nvidia/DeepSeek-R1-FP4 \
  --dataset-name random \
  --ignore-eos \
  --num-prompts 256 \
  --max-concurrency 256 \
  --random-input-len 128 \
  --random-output-len 1024 
    
vs. the naive baseline, i.e. the same commands with:
 VLLM_ALL2ALL_BACKEND="naive" \
 ...

Test Results

Accuracy:

vllm (pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,data_parallel_size=4,enable_expert_parallel=True,tensor_parallel_size=1,enforce_eager=False,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9401|±  |0.0065|
|     |       |strict-match    |     5|exact_match|↑  |0.9386|±  |0.0066|

Perf:

Allgather-ReduceScatter:
Successful requests:                     256       
Benchmark duration (s):                  120.41    
Total input tokens:                      32512     
Total generated tokens:                  262144    
Request throughput (req/s):              2.13      
Output token throughput (tok/s):         2177.05   
Total Token throughput (tok/s):          2447.06   
---------------Time to First Token----------------
Mean TTFT (ms):                          1598.62   
Median TTFT (ms):                        1599.74   
P99 TTFT (ms):                           1628.07   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          116.10    
Median TPOT (ms):                        116.10    
P99 TPOT (ms):                           116.11    
---------------Inter-token Latency----------------
Mean ITL (ms):                           116.10    
Median ITL (ms):                         115.87    
P99 ITL (ms):                            124.38    

Naive:
Successful requests:                     256       
Benchmark duration (s):                  236.75    
Total input tokens:                      32512     
Total generated tokens:                  262144    
Request throughput (req/s):              1.08      
Output token throughput (tok/s):         1107.26   
Total Token throughput (tok/s):          1244.58   
---------------Time to First Token----------------
Mean TTFT (ms):                          2660.87   
Median TTFT (ms):                        2665.06   
P99 TTFT (ms):                           2692.93   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          228.78    
Median TPOT (ms):                        228.78    
P99 TPOT (ms):                           228.79    
---------------Inter-token Latency----------------
Mean ITL (ms):                           228.78    
Median ITL (ms):                         224.91    
P99 ITL (ms):                            271.71    


Signed-off-by: Shu Wang. <shuw@nvidia.com>
Comment on lines +1035 to +1036
# - "allgather_reducescatter": all2all implementation based on allgather and
# reducescatter
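
For context, here is a minimal PyTorch sketch of the dispatch/combine pattern this backend describes (illustrative only; the function and tensor names are hypothetical, and this is not vLLM's implementation):

import torch
import torch.distributed as dist

def moe_allgather_reducescatter(
    hidden: torch.Tensor,  # [tokens_per_rank, hidden_dim], this rank's tokens
    local_experts_fn,      # applies this rank's experts to the gathered tokens
) -> torch.Tensor:
    world = dist.get_world_size()

    # Dispatch: all-gather so every rank sees all tokens and can run the
    # subset routed to its local experts.
    gathered = hidden.new_empty((world * hidden.shape[0], hidden.shape[1]))
    dist.all_gather_into_tensor(gathered, hidden)

    # Partial outputs: this rank's experts' weighted contribution for every
    # token, zeros for tokens that selected none of its experts.
    partial = local_experts_fn(gathered)

    # Combine: reduce-scatter sums the partials across ranks and returns
    # each rank only its own tokens.
    out = torch.empty_like(hidden)
    dist.reduce_scatter_tensor(out, partial)
    return out

Because both collectives have static shapes and a fixed communication pattern, this approach also composes well with CUDA graph capture (see the discussion at the end of the thread).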

Collaborator:
@tlrmchlsmth wdyt if we turn this on as the default?

@varun-sundar-rabindranath (Contributor) commented Sep 9, 2025

I think this is a better default.
[edit] I'd recommend making this switch in a followup PR with some testing

Member:
yeah, we should make this the default

@varun-sundar-rabindranath (Contributor) left a comment:
Nice and clean change. Thanks @wenscarl

@tlrmchlsmth (Member) left a comment:
Don't we need to add a prepare/finalize implementation that uses this backend?

@wenscarl (Contributor, Author) replied:

> Don't we need to add a prepare/finalize implementation that uses this backend?

Not at this moment, since the trtllm-moe kernel's API is quite different from what the modular kernel supports.
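
For readers following along, a rough sketch of the kind of prepare/finalize split being discussed (hypothetical and heavily simplified; not vLLM's actual modular-kernel API):

import torch
from abc import ABC, abstractmethod

class PrepareAndFinalize(ABC):
    """Hypothetical interface: dispatch tokens before expert compute,
    combine partial outputs afterwards."""

    @abstractmethod
    def prepare(self, hidden: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
        """Route this rank's tokens to the ranks owning their experts."""

    @abstractmethod
    def finalize(self, expert_out: torch.Tensor) -> torch.Tensor:
        """Sum partial expert outputs and return this rank's own tokens."""

If, as the comment above suggests, the trtllm-moe kernel exposes a fused end-to-end API, it would not slot cleanly into a separate prepare/finalize pair like this.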

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed performance Performance-related issues moe labels Sep 11, 2025
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>

@tlrmchlsmth (Member) left a comment:
I didn't realize how this was hooked up. Tried it and LGTM!

I also switched the default All2All backend to this one.

@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) September 11, 2025 21:31

@simon-mo (Collaborator):

@tlrmchlsmth the test failures should be unrelated? Can you confirm?

@tlrmchlsmth (Member):

The failures might actually be related, as we changed the default All2All backend. I'm on my phone, so I can't read the logs well right now.

tlrmchlsmth and others added 2 commits September 12, 2025 16:02
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>

@nvpohanh (Contributor):

If we can confirm that the performance is good, could we enable it by default so that users do not need to add the VLLM_ALL2ALL_BACKEND="allgather_reducescatter" env var?

cc @weireweire

@tlrmchlsmth (Member) commented Sep 15, 2025

> If we can confirm that the performance is good, could we enable it by default so that users do not need to add the VLLM_ALL2ALL_BACKEND="allgather_reducescatter" env var?
>
> cc @weireweire

Yes, I had this enabled by default but backed out that change in case it was breaking the CI. The distributed tests are still red, though, so I'm not sure what the problem is now. I'll try to land this PR with the default set to allgather_reducescatter.

mergify bot commented Sep 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 17, 2025

@tlrmchlsmth (Member):

@wenscarl Could the failure be related? The distributed tests are green on a very recent nightly https://buildkite.com/vllm/ci/builds/30655#01994662-1fcb-4cd8-aba7-ee88e6f8608a

Signed-off-by: Shu Wang <shuw@nvidia.com>
auto-merge was automatically disabled September 18, 2025 13:30

Head branch was pushed to by a user without write access

@wenscarl (Contributor, Author) replied:

> @wenscarl Could the failure be related? The distributed tests are green on a very recent nightly https://buildkite.com/vllm/ci/builds/30655#01994662-1fcb-4cd8-aba7-ee88e6f8608a

I didn't see any significant failures. Did I miss anything?

@mergify mergify bot removed the needs-rebase label Sep 18, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) September 18, 2025 15:46

@tlrmchlsmth (Member):

🤞 that the distributed tests are green

@tlrmchlsmth tlrmchlsmth added this to the v0.11.0 milestone Sep 18, 2025
@tlrmchlsmth tlrmchlsmth merged commit 2ea50e9 into vllm-project:main Sep 18, 2025
49 checks passed

Commits referencing this pull request were later pushed to debroy-rh/vllm (Sep 19, 2025), FeiDaLI/vllm (Sep 25, 2025), ROCm/vllm (Sep 25, 2025), xuebwang-amd/vllm (Oct 10 and Oct 24, 2025), Tandemn-Labs/vllm (Oct 11, 2025), and lywa1998/vllm (Oct 20, 2025).

@zejunchen-zejun (Contributor):

Hi @wenscarl,
Wonderful work! May I know whether this backend is compatible with CUDA graphs? The dispatch (all-gather) and combine (reduce-scatter) can be captured by a CUDA graph, right?
Thank you!

@nvpohanh (Contributor):

> Hi @wenscarl, Wonderful work! May I know whether this backend is compatible with CUDA graphs? The dispatch (all-gather) and combine (reduce-scatter) can be captured by a CUDA graph, right? Thank you!

We have been running with CUDA graphs and didn't see any issues with allgather_reducescatter.
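
For illustration, NCCL collectives such as all-gather can be captured into CUDA graphs in PyTorch. A minimal standalone sketch (hypothetical names, not vLLM code; assumes torch.distributed is already initialized with the NCCL backend):

import torch
import torch.distributed as dist

def capture_all_gather(inp: torch.Tensor):
    world = dist.get_world_size()
    gathered = inp.new_empty((world * inp.shape[0],) + tuple(inp.shape[1:]))

    # CUDA graphs require a warm-up run on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        dist.all_gather_into_tensor(gathered, inp)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the collective; replaying the graph re-runs it on the same
    # input/output buffers.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        dist.all_gather_into_tensor(gathered, inp)
    return g, gathered  # copy new data into `inp`, then call g.replay()

The same pattern applies to reduce_scatter_tensor for the combine step; both collectives use fixed buffer shapes across decode iterations, which is what graph capture requires.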


Labels

moe, performance, ready

Development

Successfully merging this pull request may close these issues.

[Feature][Kernel][B200]: FI MoE LL does not use allgatherv and reduce-scatterv for dispatch and combine

7 participants