Conversation

@wuxun-zhang (Contributor) commented Sep 28, 2025

After vllm-project/vllm#24982 merged, sequence-parallel MoE is turned on when `enable_expert_parallel=True`, `tp_size > 1`, and `dp_size > 1`. Since there is no alternative `VLLM_ALL2ALL_BACKEND` choice for Gaudi, we cannot easily bypass it, so this PR adds support for the feature.

```python
class ParallelConfig:

    @property
    def use_sequence_parallel_moe(self) -> bool:
        return (envs.VLLM_ALL2ALL_BACKEND
                in ("allgather_reducescatter", "naive",
                    "deepep_high_throughput", "deepep_low_latency")
                and self.enable_expert_parallel
                and self.tensor_parallel_size > 1
                and self.data_parallel_size > 1)
```
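For context, here is a minimal PyTorch sketch of what sequence-parallel MoE does around the expert layers, assuming the allgather/reducescatter pattern named in the backend list above; `moe_layer` and `tp_group` are illustrative placeholders, not the actual vLLM implementation:

```python
# Illustrative sketch only, not the vLLM/vllm-gaudi implementation.
import torch
import torch.distributed as dist

def sequence_parallel_moe(hidden_states: torch.Tensor,
                          moe_layer,
                          tp_group) -> torch.Tensor:
    """Shard tokens across the TP group around the MoE.

    hidden_states: (num_tokens, hidden_size), with num_tokens
    divisible by the TP world size.
    """
    tp_size = dist.get_world_size(group=tp_group)
    num_tokens, hidden_size = hidden_states.shape

    # Reduce-scatter: each TP rank keeps the summed activations for its
    # token shard, so the MoE only sees num_tokens / tp_size tokens.
    local = torch.empty(num_tokens // tp_size, hidden_size,
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)
    dist.reduce_scatter_tensor(local, hidden_states, group=tp_group)

    # Run the (expert-parallel) MoE on the local token shard only.
    local = moe_layer(local)

    # All-gather the shards so every rank sees the full sequence again.
    out = torch.empty_like(hidden_states)
    dist.all_gather_into_tensor(out, local, group=tp_group)
    return out
```

This is also why the property above only fires with `tp_size > 1`: with a single TP rank there is no group to shard the tokens across.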

Update:
No hard requirement on vllm-project/vllm#25828

Signed-off-by: Wuxun Zhang <wuxun.zhang@intel.com>
@github-actions commented

✅ CI Passed

All checks passed successfully against the following vllm commit:
c242c98031b87d00999e07dbb4aa9b2a70798c6c

@xuechendi (Collaborator) commented

@wuxun-zhang, we are trying to make the (bs, seq_len, hidden_state) shape not a hard requirement for HPU.
Please check this new flag: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/extension/features.py#L89

Please try adding GraniteMOE so that input_ids is flattened to 1D (a sketch follows this comment).

Meanwhile, I also discussed with @kzawora-intel: since more and more models assert on 2D input, we might change to 1D as the default once performance is validated.
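For illustration, a minimal sketch of the 1D flattening being discussed; the shapes here are hypothetical, and the real switch is the features.py flag linked above:

```python
import torch

# Batched 2D input of shape (bs, seq_len), the layout most HPU paths assume.
input_ids = torch.tensor([[101, 7592, 102],
                          [101, 2088, 102]])  # shape (2, 3)

# Flattened 1D stream of shape (bs * seq_len,), the layout that
# GraniteMOE-style models assert on.
flat_ids = input_ids.flatten()               # shape (6,)
assert flat_ids.shape == (input_ids.numel(),)
```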

@wuxun-zhang (Contributor, Author) commented

> Please try adding GraniteMOE so that input_ids is flattened to 1D.

Thanks, it works. Just updated.

@xuechendi enabled auto-merge (squash) September 30, 2025 01:07
@xuechendi merged commit 922a18f into vllm-project:main Sep 30, 2025
34 checks passed
@github-actions commented

✅ CI Passed

All checks passed successfully against the following vllm commit:
c242c98031b87d00999e07dbb4aa9b2a70798c6c

iboiko-habana pushed a commit to iboiko-habana/vllm-gaudi that referenced this pull request Oct 2, 2025