[XPU] Fix MOE DP accuracy issue on XPU #25465
Conversation
Signed-off-by: Fanli Lin <fanli.lin@intel.com>
Code Review
This pull request addresses an accuracy issue with Mixture-of-Experts (MoE) models on XPU devices when using data parallelism. The core of the fix is implementing the `dispatch` and `combine` communication primitives in `XpuCommunicator`, which are essential for expert parallelism; the changes correctly delegate these operations to the `all2all_manager`. Additionally, the PR makes the all2all backend on XPU robust by defaulting to the `naive` implementation, the only one currently supported, and warns the user if a different backend is configured. The modification to the data parallelism example script that makes `enable_expert_parallel` a configurable argument is also a good improvement for flexibility. The provided test results clearly demonstrate the effectiveness of the fix. The changes are well implemented and follow existing patterns in the codebase. Overall, this is a solid contribution to improving XPU support.
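The delegation described above likely mirrors the pattern used by vLLM's other device communicators. A minimal sketch, assuming `XpuCommunicator` derives from the common `DeviceCommunicatorBase` and holds an `all2all_manager` the way the CUDA communicator does (the exact diff and module path may differ):

```python
import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class XpuCommunicator(DeviceCommunicatorBase):
    # ... existing XPU collectives (all_reduce, gather, ...) omitted ...

    def dispatch(
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Scatter tokens and router logits to the ranks that own the
        # selected experts, via the configured all2all manager.
        assert self.all2all_manager is not None
        hidden_states, router_logits = self.all2all_manager.dispatch(
            hidden_states, router_logits)
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Gather the expert outputs back to their originating ranks.
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(hidden_states)
```

With these methods in place, the MoE expert-parallel path can call the same communicator API on XPU as it does on other devices instead of silently skipping the exchange.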
Signed-off-by: Fanli Lin <fanli.lin@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Purpose
- Add `dispatch` and `combine` methods in `XpuCommunicator` to fix the MOE model accuracy issue on XPU
- Make `naive` the default `all2all_backend` on XPU
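The backend fallback could look roughly like the following. This is a sketch only: `resolve_xpu_all2all_backend` is a hypothetical helper name, and where the check actually lives (e.g. in the XPU platform's config validation) is an assumption.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_xpu_all2all_backend(configured_backend: str) -> str:
    """Return the all2all backend to use on XPU.

    Only the 'naive' backend is currently supported on XPU, so any other
    setting is overridden with a warning instead of producing wrong
    results later.
    """
    if configured_backend != "naive":
        logger.warning(
            "all2all backend %r is not supported on XPU; "
            "falling back to 'naive'.", configured_backend)
    return "naive"
```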
Test Plan

`VLLM_WORKER_MULTIPROC_METHOD=spawn python examples/offline_inference/data_parallel.py --enforce-eager --model="ibm-research/PowerMoE-3b" --dp-size=2 --tp-size=2 --disable-expert-parallel`

Before:
After:
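For completeness, the example-script change mentioned in the review (making expert parallelism configurable, which is what the `--disable-expert-parallel` flag in the command above relies on) might be wired up roughly as follows. Only the flag name comes from the command above; the helper name and surrounding structure are assumptions.

```python
import argparse


def add_expert_parallel_arg(parser: argparse.ArgumentParser) -> None:
    # Hypothetical helper: expert parallelism stays enabled by default,
    # and --disable-expert-parallel switches it off.
    parser.add_argument(
        "--disable-expert-parallel",
        dest="enable_expert_parallel",
        action="store_false",
        help="Disable expert parallelism (enabled by default).")


# The resulting args.enable_expert_parallel value would then be forwarded
# to the engine, e.g.
#   LLM(..., enable_expert_parallel=args.enable_expert_parallel)
```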