- 
                Notifications
    
You must be signed in to change notification settings  - Fork 533
 
[WIP][BugFix]Fix accuracy issues caused by wrong etp_size passed into FusedMoEParallelConfig when using vLLM 0.9.0 #961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Merged
      
      
    Conversation
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
    Signed-off-by: angazenn <zengyanjia@huawei.com>
              
                    ganyi1996ppo
  
              
              approved these changes
              
                  
                    May 27, 2025 
                  
              
              
            
            
    
  zxdukki 
      pushed a commit
        to zxdukki/vllm-ascend
      that referenced
      this pull request
    
      Jun 3, 2025 
    
    
      
  
    
      
    
  
… FusedMoEParallelConfig when using vLLM 0.9.0 (vllm-project#961) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? This PR fix accuracy issues incurred by codes that adapt to `FusedMoEParallelConfig` in vLLM 0.9.0 version. The `tp_size` used to split weights are wrongly passed. The root cause is that vLLM community and vLLM-Ascend are using different methods to decide whether to use Expert Parallel. vLLM: vLLM use a flag `enable_expert_parallel` to indicate whether to use EP and use the following codes to decide `ep_size`: ``` use_ep = (dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: return FusedMoEParallelConfig(tp_size=tp_size, tp_rank=tp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, ep_rank=0, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, use_ep=True) ``` vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE. ``` self.ep_size = get_ep_group().world_size self.tp_size = get_etp_group().world_size self.dp_size = (dp_size if dp_size is not None else get_dp_group().world_size) ``` So there will be conflicts if we simply combine these codes together. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
    
  chopper0126 
      pushed a commit
        to chopper0126/vllm-ascend
      that referenced
      this pull request
    
      Oct 16, 2025 
    
    
      
  
    
      
    
  
… FusedMoEParallelConfig when using vLLM 0.9.0 (vllm-project#961) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? This PR fix accuracy issues incurred by codes that adapt to `FusedMoEParallelConfig` in vLLM 0.9.0 version. The `tp_size` used to split weights are wrongly passed. The root cause is that vLLM community and vLLM-Ascend are using different methods to decide whether to use Expert Parallel. vLLM: vLLM use a flag `enable_expert_parallel` to indicate whether to use EP and use the following codes to decide `ep_size`: ``` use_ep = (dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: return FusedMoEParallelConfig(tp_size=tp_size, tp_rank=tp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, ep_rank=0, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, use_ep=True) ``` vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE. ``` self.ep_size = get_ep_group().world_size self.tp_size = get_etp_group().world_size self.dp_size = (dp_size if dp_size is not None else get_dp_group().world_size) ``` So there will be conflicts if we simply combine these codes together. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
    
  Angazenn 
      added a commit
        to Angazenn/vllm-ascend
      that referenced
      this pull request
    
      Oct 21, 2025 
    
    
      
  
    
      
    
  
… FusedMoEParallelConfig when using vLLM 0.9.0 (vllm-project#961) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? This PR fix accuracy issues incurred by codes that adapt to `FusedMoEParallelConfig` in vLLM 0.9.0 version. The `tp_size` used to split weights are wrongly passed. The root cause is that vLLM community and vLLM-Ascend are using different methods to decide whether to use Expert Parallel. vLLM: vLLM use a flag `enable_expert_parallel` to indicate whether to use EP and use the following codes to decide `ep_size`: ``` use_ep = (dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: return FusedMoEParallelConfig(tp_size=tp_size, tp_rank=tp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, ep_rank=0, use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, ep_rank=ep_rank, use_ep=True) ``` vLLM-Ascend: vLLM-Ascend uses `etp` to specify Tensor Parallel in MoE. ``` self.ep_size = get_ep_group().world_size self.tp_size = get_etp_group().world_size self.dp_size = (dp_size if dp_size is not None else get_dp_group().world_size) ``` So there will be conflicts if we simply combine these codes together. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment
  
      
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
What this PR does / why we need it?
This PR fix accuracy issues incurred by codes that adapt to
FusedMoEParallelConfigin vLLM 0.9.0 version. Thetp_sizeused to split weights are wrongly passed. The root cause is that vLLM community and vLLM-Ascend are using different methods to decide whether to use Expert Parallel.vLLM:
vLLM use a flag
enable_expert_parallelto indicate whether to use EP and use the following codes to decideep_size:vLLM-Ascend:
vLLM-Ascend uses
etpto specify Tensor Parallel in MoE.So there will be conflicts if we simply combine these codes together.
Does this PR introduce any user-facing change?
How was this patch tested?