
Commit d3d649e

hmellor and Isotr0py authored
Support expert parallel in Transformers backend (#26162)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent ea507c3 commit d3d649e

File tree: 2 files changed (+32, -21 lines)

docs/models/supported_models.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
     - Data parallel
-    - Pipeline parallel
     - Tensor parallel
+    - Expert parallel
+    - Pipeline parallel
 
 Checking if the modeling backend is Transformers is as simple as:
 
```
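As a quick illustration of combining these schemes with the Transformers backend, here is a minimal offline-inference sketch. It is not part of this commit: the checkpoint name is a placeholder and the keyword arguments (`model_impl`, `tensor_parallel_size`, `enable_expert_parallel`) are assumed to follow vLLM's standard engine arguments.

```python
# Minimal sketch (not from this commit): Transformers backend with
# tensor parallel + expert parallel. The checkpoint name is a placeholder
# and the kwargs assume vLLM's standard engine arguments.
from vllm import LLM

llm = LLM(
    model="some-org/some-moe-model",   # placeholder MoE checkpoint
    model_impl="transformers",         # force the Transformers backend
    tensor_parallel_size=2,            # tensor parallel
    enable_expert_parallel=True,       # expert parallel (enabled by this commit)
)

print(llm.generate("Hello, my name is"))
```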

vllm/model_executor/models/transformers_moe.py

Lines changed: 30 additions & 20 deletions
```diff
@@ -22,6 +22,7 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
 
         self.custom_routing_function = custom_routing_function
 
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
 
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
 
 
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
```
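The size check in `custom_routing_function` covers the case where, with expert parallel enabled, `FusedMoE` routes over hidden states that have been gathered across the data-parallel group (or the expert-parallel group under sequence parallel), while the stored `topk_ids` only cover this rank's tokens. The sketch below illustrates that shape alignment on a single process; plain `torch.cat` stands in for `all_gatherv`, and all sizes are made up.

```python
# Single-process illustration of the shape alignment performed in
# custom_routing_function; torch.cat stands in for dist_group.all_gatherv
# and the per-rank token counts are made up.
import torch

topk = 2
sizes = [3, 5, 4]        # tokens held by each rank in the gather group
rank = 1                 # pretend this process is rank 1

# Routing ids produced locally by the Transformers model on this rank.
local_topk_ids = torch.randint(0, 8, (sizes[rank], topk))

# Hidden states as FusedMoE sees them once all ranks' tokens are gathered.
hidden_states = torch.randn(sum(sizes), 16)

topk_ids = local_topk_ids
if topk_ids.size(0) != hidden_states.size(0):
    # The real code asserts the local chunk matches this rank's size,
    # then all-gathers the ids using the per-rank chunk sizes.
    assert sizes[rank] == topk_ids.shape[0]
    chunks = [torch.randint(0, 8, (n, topk)) for n in sizes]
    chunks[rank] = topk_ids
    topk_ids = torch.cat(chunks, dim=0)   # what all_gatherv would return

assert topk_ids.shape[0] == hidden_states.shape[0]
```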

```diff
@@ -96,9 +109,6 @@ def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "
```
