abmfy (Member) commented May 19, 2025

This PR introduces support for dynamic load balancing in expert parallelism (EP) for the deployment of Mixture-of-Experts (MoE) models.

Dynamic load balancing is essential for auxiliary-loss-free MoE models, such as the DeepSeek-V3/R1 series. This feature enables dynamic rearrangement of experts across different ranks/nodes to achieve better load balance during inference.

Additionally, this PR introduces support for redundant experts, allowing each routed expert to maintain multiple parameter copies distributed across different ranks. This further improves expert load balancing.
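To make the redundant-experts idea concrete, here is a minimal sketch (a hypothetical helper, not the PR's actual code) of mapping logical experts to physical slots by greedily giving the extra replicas to the most-loaded experts:

```python
def build_physical_expert_map(loads, num_redundant):
    """Assign each logical expert one physical slot, then hand the
    redundant slots to whichever expert has the highest per-replica load."""
    num_logical = len(loads)
    replicas = [1] * num_logical
    for _ in range(num_redundant):
        # Pick the expert whose load per replica is currently highest.
        hottest = max(range(num_logical), key=lambda e: loads[e] / replicas[e])
        replicas[hottest] += 1
    # Flatten into a physical-slot -> logical-expert mapping.
    physical_to_logical = []
    for expert, count in enumerate(replicas):
        physical_to_logical.extend([expert] * count)
    return physical_to_logical

# Expert 1 carries most of the load, so it receives both redundant slots.
print(build_physical_expert_map([10, 80, 10, 20], num_redundant=2))
# → [0, 1, 1, 1, 2, 3]
```

The actual EPLB rearrangement also has to respect rank placement and group constraints; this sketch only illustrates the replication idea.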

Running

To try out EPLB, enable it with the following options:

```
--enable-eplb
--num-redundant-experts 32
--eplb-window-size 1000
--eplb-step-interval 3000
```
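Putting the flags together, a launch command might look like the sketch below (the model name and tensor-parallel size are illustrative placeholders; note that EPLB also requires expert parallelism to be enabled):

```shell
vllm serve deepseek-ai/DeepSeek-V3 \
    --tensor-parallel-size 8 \
    --enable-expert-parallel \
    --enable-eplb \
    --num-redundant-experts 32 \
    --eplb-window-size 1000 \
    --eplb-step-interval 3000
```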

You should see a log message indicating that EPLB is enabled, as well as periodic logs showing the rearrangement of experts.

Compatibility

Currently, we support DeepSeek-V2, V3, and R1 models with FP8 quantization. However, this PR has been designed with generality in mind, so extending support to other MoE models or quantization methods should be straightforward.

Adding model support:

To add support for a new model, implement the MixtureOfExperts protocol. In essence, you’ll need to:

  • Expose relevant MoE configurations.
  • Provide access to the expert weights that need to be shuffled.
  • Forward EPLB-related information into the FusedMoE layer.
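The shape of such a protocol can be sketched structurally; the member names below are assumptions for illustration, not vLLM's actual definitions (see the real protocol in vLLM's model interfaces module):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class MixtureOfExperts(Protocol):
    """Illustrative sketch only; member names are assumptions."""

    # MoE configuration the EPLB rebalancer needs to see.
    num_moe_layers: int
    num_logical_experts: int
    num_physical_experts: int

    @property
    def expert_weights(self):
        """Per-layer expert weight tensors that EPLB may shuffle."""
        ...


class ToyMoEModel:
    """A dummy model that satisfies the structural check above."""

    num_moe_layers = 2
    num_logical_experts = 8
    num_physical_experts = 10  # 8 logical experts + 2 redundant copies

    @property
    def expert_weights(self):
        # Placeholder for the real per-layer weight tensors.
        return [[None] for _ in range(self.num_moe_layers)]


print(isinstance(ToyMoEModel(), MixtureOfExperts))  # → True
```

Because the protocol is structural, a model class opts in simply by exposing the required members; no inheritance is needed.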

Note: Pay close attention to the weight-loading logic. With redundant experts, there is extra complexity in ensuring weights are loaded correctly. The expert_params_mapping returned by FusedMoE already accounts for redundant experts, but you may still need nontrivial adjustments in the model class to avoid breaking the weight-loading process.

You can refer to the implementation changes in deepseek_v2.py.

Adding quantization support:

Adding quantization support should be straightforward, as it mainly involves forwarding the necessary arguments.

See the changes in fp8.py for reference.

We welcome contributions to help add support for additional models and quantization methods!

To-Dos

To-Do List for this PR:

  • Implement replicated experts in fused MoE operations
  • Monitor expert balancedness in metrics
  • Remove magic numbers
  • Allow turning off monitoring since it brings some overhead

Long-term To-Do List (should be done in other PRs):

  • Model Execution
    • When using FusedMoEModularKernel, we can directly use the load metrics returned by FusedMoEPrepareAndFinalize instead of calculating them inside expert selection. We are not doing this yet because not all code paths use FusedMoEModularKernel.
  • EPLB Algorithm
    • Add other rebalancing strategies, e.g. rebalancing when balancedness falls below a threshold
    • Consider treating prefill and decode nodes differently in the rearrangement algorithm
  • EPLB Execution
    • Parallelize the rearrangement algorithm (calculating new expert mapping, not the communication)
    • Shuffle one layer at a time over multiple steps, to reduce the impact on inter-token latency
    • Investigate whether to pre-allocate the expert weight buffer used for transfers
    • Take locality into consideration in expert weight transmission, e.g. prioritize transferring to GPUs on the same node
  • Compatibility
    • Add support for DeepSeek Multi-Token Prediction (MTP) layers
    • Add support for two-batch overlap ([WIP] Two batch overlap #18415)
    • Add support for other MoE models, e.g. Llama 4, Qwen3
    • Add support for other quantization methods
  • API
    • Group EPLB configs together

abmfy added 7 commits May 14, 2025 16:29
WIP, design choices not finalized.
@mergify mergify bot added the v1 label May 19, 2025

mergify bot commented May 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @abmfy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

abmfy added 2 commits May 20, 2025 18:02
abmfy added 11 commits May 23, 2025 20:19
Moved into `FusedMoE` layers

Since `grouped_topk` will assume top-2 for DeepSeek-V3

WoosukKwon merged commit e9fd658 into vllm-project:main on Jun 26, 2025
96 of 101 checks passed

ztxdcyy commented Jun 27, 2025

🎉 So happy to see this PR finally merged after going through so many challenges — big round of applause for the researcher's persistence and dedication! @abmfy 👏👏👏

Also, just wondering — how can we measure the benefits brought by EPLB? 🤔
Things like expert balancing, GPU utilization, TTFT, TPOT... Any suggestions or best practices? 💡

abmfy (Member, Author) commented Jun 27, 2025


Hi @ztxdcyy, thanks for your attention!

There’s now a default-off option --eplb-log-balancedness that logs the load balance factor across different GPUs at each step.
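As a point of reference, balancedness metrics of this kind are typically defined as mean load over max load across GPUs. A hypothetical sketch (not necessarily the exact formula behind the flag):

```python
def balancedness(tokens_per_gpu):
    """Mean-over-max load ratio: 1.0 means perfectly balanced, while
    values near 1/num_gpus mean one GPU is doing almost all the work."""
    peak = max(tokens_per_gpu)
    if peak == 0:
        return 1.0  # no load at all counts as balanced
    return sum(tokens_per_gpu) / len(tokens_per_gpu) / peak

print(balancedness([100, 100, 100, 100]))  # → 1.0
print(balancedness([400, 0, 0, 0]))        # → 0.25
```

Tracking this value before and after enabling EPLB gives a direct view of how much the rearrangement helps.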

As for other metrics, I believe they’re not specific to the EPLB settings, so we can simply rely on standard metrics by running benchmarks and monitoring those results as usual.

Let me know what you think!

@Lichunyan3

@abmfy Hello, I'm encountering the following error when using multi-GPU parallel processing
Here's my startup command:
```
python -m vllm.entrypoints.openai.api_server \
    --model="/public/models/hf_models/DeepSeek-V2-Lite-Chat-FP8-A16" \
    --trust-remote-code -tp 2 -dp 2 --port 8200 \
    --enforce-eager --enable-eplb --eplb-log-balancedness
```

This issue doesn't occur when starting with a single GPU and only appears during multi-GPU parallel processing. Have you encountered this before, or do you have any solutions?

(screenshot of the error omitted)

abmfy (Member, Author) commented Jul 31, 2025


Hi @Lichunyan3, sorry for the late reply—I was traveling.

It looks like you may have missed adding --enable-expert-parallel; EPLB requires running under EP.

We’ve added some checks in #21102, so if EPLB is enabled without EP, it will now raise an error.

@Bounty-hunter

Did you test how balancedness improves with benchmark_serving.py? Since it uses a random dataset, will there be a significant improvement?

Labels

ci/build · performance (Performance-related issues) · qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed) · v1
