
Conversation

@cakeng
Contributor

@cakeng cakeng commented Jan 30, 2025

The current vLLM execution only supports TP when running MoE models.

This PR adds support for Expert Parallelism (EP) in the FusedMoE kernel and the DeepSeek V2 model, which should be extendable to V3 and other MoE models as well.

When VLLM_TEST_ENABLE_EP=1, EP is automatically applied to the FusedMoE layers, using the user-specified tensor_parallel_size as the expert parallel size.

CUDA graph works with EP.
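
A minimal offline-inference sketch of the enablement path described above, assuming the environment variable from this PR; the model id, TP size, and prompt are illustrative placeholders, not part of the PR:

```python
# Sketch only: EP is toggled via the env var described in this PR. The model id,
# tensor_parallel_size, and prompt are illustrative assumptions.
import os

os.environ["VLLM_TEST_ENABLE_EP"] = "1"  # EP size is taken from tensor_parallel_size

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # assumed model id for illustration
    tensor_parallel_size=2,                # reused as the expert-parallel size
    trust_remote_code=True,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```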

Doc

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Collaborator

@comaniac comaniac left a comment


Left some comments but overall LGTM

@youkaichao
Member

Need to check the user interface, but I feel right now we can just reuse the TP size for EP?

In the future, when we have DP, the EP size will automatically be DP x TP.
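
A tiny illustration (not vLLM code) of the sizing relationship described in this comment; the helper name is hypothetical:

```python
# Hypothetical helper illustrating the sizing discussed above, not vLLM code:
# today EP reuses the TP size; once DP exists, EP would span DP x TP ranks.
def expert_parallel_size(tp_size: int, dp_size: int = 1) -> int:
    return dp_size * tp_size

assert expert_parallel_size(tp_size=8) == 8               # current: EP size == TP size
assert expert_parallel_size(tp_size=8, dp_size=2) == 16   # future: EP size == DP x TP
```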

@LucasWilkinson
Collaborator

> but the problem with CUDA graph seems to be on the current attention layer (MLA?) implementation.

Can you please elaborate on this a bit? MLA + CUDA graphs + TP is working fine on main as far as I am aware.

@mergify mergify bot added the v1 label Feb 4, 2025
@cakeng
Contributor Author

cakeng commented Feb 4, 2025

@youkaichao The current design supports TP within EP, but we can easily change that to have EP only on the MoE layers. I think we will need more discussion with others on that design decision; the current implementation of EP+TP is based on a discussion with @WoosukKwon and @simon-mo.

@LucasWilkinson I just merged the main branch and CUDA graph is now working with EP+TP.

@simon-mo
Collaborator

@cakeng a test failure in CI

https://buildkite.com/vllm/ci/builds/14038#019534e8-1eef-4be1-8e89-3f718502ddb4/2582-4097

[2025-02-23T23:32:34Z] FAILED models/test_initialization.py::test_can_initialize[QuantMixtralForCausalLM] - AttributeError: 'FusedMoE' object has no attribute 'num_experts'

…o benchmark_moe.py

@simon-mo simon-mo merged commit 781096e into vllm-project:main Feb 24, 2025
54 of 55 checks passed
@simon-mo simon-mo changed the title Expert Parallelism (EP) Support for DeepSeek V2 Expert Parallelism (EP) Support for DeepSeek Models Feb 25, 2025
@cakeng cakeng deleted the moe branch February 26, 2025 08:33
@liweiqing1997

Hello, I would like to ask about the communication overhead of the current implementation.

It seems that regardless of whether EP (Expert Parallelism) is enabled for the MoE layers, each rank still performs an all-reduce on the hidden states, so the amount of data transmitted is the same whether EP is enabled or not. The only difference is the communication latency, because all ranks need to synchronize and wait for the MoE computation to finish when EP is enabled. I'm not sure if my analysis is correct.

I noticed that after enabling EP on a 16-card setup, the time taken by the all-reduce kernel doubled and throughput dropped by 30%.
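
A back-of-envelope sketch of the data-volume point in this comment; the token count, hidden size, and dtype are illustrative assumptions, not measurements from the report above:

```python
# Rough arithmetic only: per this analysis, the hidden-state all-reduce moves the
# same number of bytes per MoE layer whether EP is on or off.
num_tokens, hidden_size, bytes_per_elem = 4096, 5120, 2   # assumed batch, dim, bf16

allreduce_payload = num_tokens * hidden_size * bytes_per_elem
print(f"hidden-state all-reduce payload per MoE layer: {allreduce_payload / 2**20:.1f} MiB")
```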

@lewisword

lewisword commented Mar 14, 2025

May I ask if this feature can be used for online serving? I see from the example in examples/offline_inference/data_parallel.py that it uses an offline, multi-process invocation approach. @cakeng

@cakeng
Contributor Author

cakeng commented Mar 18, 2025

@lewisword Yes, it should work, but the API to enable EP has changed to the --enable-expert-parallel engine argument (#14305). Also, EP performance is not necessarily better than TP in the current implementation.
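
A minimal sketch of the newer enablement path mentioned above, assuming the --enable-expert-parallel engine argument is also accepted as an LLM constructor keyword; the model id and TP size are illustrative:

```python
# Sketch under the assumption that the engine argument is exposed as a kwarg;
# model id and parallel sizes are illustrative.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    tensor_parallel_size=2,
    enable_expert_parallel=True,   # replaces the earlier VLLM_TEST_ENABLE_EP=1 env var
    trust_remote_code=True,
)
```

For online serving, the same engine argument should be usable on the server command line, which would address the question above about a service-oriented setup.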

@xiuxin121

It looks like vLLM EP has not yet implemented overlap between communication (dispatch + combine) and computation? This seems very important for MoE models like DeepSeek.
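
A generic illustration (not vLLM code) of the communication/computation overlap this comment asks about, assuming torch.distributed has already been initialized and the tensors are on the GPU; all names are hypothetical:

```python
# Generic overlap sketch: launch the dispatch all-to-all asynchronously on a side
# stream while independent compute runs on the default stream, then synchronize.
# Assumes dist.init_process_group(...) was called and all tensors are CUDA tensors.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def dispatch_with_overlap(tokens_to_send, recv_buffer, independent_input, weight):
    with torch.cuda.stream(comm_stream):
        work = dist.all_to_all_single(recv_buffer, tokens_to_send, async_op=True)
    # Compute that does not depend on the dispatched tokens proceeds in parallel.
    independent_out = independent_input @ weight
    work.wait()                                        # dispatch must finish here
    torch.cuda.current_stream().wait_stream(comm_stream)
    return recv_buffer, independent_out
```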

@Neo9061

Neo9061 commented Apr 20, 2025

Can I double check whether this feature only works with --enforce-eager? When I turn this flag off, there is a Dynamo-related error.

Asking because the last statement of the PR description says it works with CUDA graph.

shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
