     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
+
 from vllm_ascend.distributed.communication_op import \
     data_parallel_reduce_scatter
 
@@ -19,11 +20,17 @@ def __init__(self, moe_config: Optional[FusedMoEConfig]):
         self.moe_config = moe_config
 
     @abstractmethod
-    def prepare(self):
+    def prepare(self,
+                hidden_states: torch.Tensor,
+                router_logits: torch.Tensor,
+                enable_shared_expert_dp: bool = False,
+                rm_router_logits: bool = False,
+                replace_allreduce: bool = False,
+                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         raise NotImplementedError("Prepare not implemented.")
 
-    @abstractmethod
-    def finalize(self):
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
         raise NotImplementedError("Combine function not implemented.")
 
 
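The new abstract signatures above mean every communicator subclass now receives the hidden states, the router logits, and the shared-expert / all-reduce flags directly. Below is a minimal sketch of a conforming subclass; the base-class name `FusedMoEPrepareAndFinalize` and the no-op single-rank behaviour are assumptions for illustration, not the vllm-ascend implementation.

```python
# Illustrative only: stand-in base class (name assumed) plus a no-op
# single-rank subclass that satisfies the new prepare()/finalize() contract.
from abc import ABC, abstractmethod

import torch


class FusedMoEPrepareAndFinalize(ABC):  # assumed name for the abstract base
    def __init__(self, moe_config=None):
        self.moe_config = moe_config

    @abstractmethod
    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        raise NotImplementedError("Prepare not implemented.")

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        raise NotImplementedError("Combine function not implemented.")


class NoOpPrepareAndFinalize(FusedMoEPrepareAndFinalize):
    """Single-rank sketch: nothing to pad, gather, or reduce."""

    def prepare(self, hidden_states, router_logits,
                enable_shared_expert_dp=False, rm_router_logits=False,
                replace_allreduce=False, gate=None):
        # If the caller dropped the router logits, recompute them from the
        # gate (vLLM linear layers return an (output, bias) pair).
        if rm_router_logits and gate is not None:
            router_logits, _ = gate(hidden_states)
        # Third tuple element is a placeholder for per-implementation state
        # (e.g. a communication mask); nothing is needed in this sketch.
        return hidden_states, router_logits, None

    def finalize(self, hidden_states, reduce_results):
        # Nothing was gathered or padded, so return the tensor unchanged.
        return hidden_states
```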
@@ -91,6 +98,8 @@ def finalize(self, hidden_states: torch.Tensor,
 
         Also, unpad the hidden states if needed.
         """
+        assert self.moe_config.tp_group is not None, "tp_group cannot be None."
+
         if not (self.enable_shared_expert_dp or self.replace_all_reduce):
             if self.tp_size > 1:
                 dist.all_gather(list(self.split_hidden_states), hidden_states,
@@ -155,6 +164,8 @@ def finalize(self, hidden_states: torch.Tensor,
 
         Also, unpad the hidden states if needed.
         """
+        assert self.moe_config.tp_group is not None, "tp_group cannot be None."
+
         if not (self.enable_shared_expert_dp or self.replace_all_reduce):
             if self.tp_size > 1:
                 dist.all_gather(list(self.split_hidden_states), hidden_states,
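Both `finalize()` hunks guard the TP collective with the new `tp_group` assert before `dist.all_gather` fills the pre-allocated `split_hidden_states` chunks, which are then concatenated and unpadded. A single-process sketch of that pad / split / concatenate / unpad round trip follows; the pad-to-a-multiple-of-`tp_size` rule and the helper names are assumptions for illustration only.

```python
# Illustrative sketch only: simulates, without torch.distributed, the buffer
# round trip that the gathered `split_hidden_states` chunks go through.
import torch


def pad_and_split(hidden_states: torch.Tensor, tp_size: int):
    """Pad the token dim to a multiple of tp_size and split into equal chunks."""
    num_tokens = hidden_states.shape[0]
    pad = (-num_tokens) % tp_size
    if pad:
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    return num_tokens, list(torch.chunk(hidden_states, tp_size, dim=0))


def concat_and_unpad(split_hidden_states, num_tokens: int) -> torch.Tensor:
    """The inverse step finalize() performs after all_gather fills the chunks."""
    return torch.cat(split_hidden_states, dim=0)[:num_tokens]


if __name__ == "__main__":
    x = torch.randn(5, 8)                       # 5 tokens, hidden size 8
    n, chunks = pad_and_split(x, tp_size=4)     # padded to 8 rows, 4 chunks of 2
    # ... each TP rank would process its own chunk; all_gather refills `chunks` ...
    y = concat_and_unpad(chunks, n)
    assert torch.equal(x, y)
```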
@@ -180,9 +191,12 @@ def prepare(self,
                 replace_all_reduce: bool = False,
                 gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """When DP size > 1, pad the hidden states and router logits for communication."""
+        assert self.moe_config.dp_size is not None, "dp_size cannot be None."
+        assert self.moe_config.dp_group is not None, "dp_group cannot be None."
+
         self.rm_router_logits = rm_router_logits
         self.enable_shared_expert_dp = enable_shared_expert_dp
-
+
         if self.moe_config.dp_size > 1:
             forward_context = get_forward_context()
             max_tokens_across_dp = forward_context.max_tokens_across_dp
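In `prepare()`, the DP branch reads `max_tokens_across_dp` from the forward context and pads each rank's hidden states and router logits up to that length, so the subsequent collective operates on equally sized tensors. A rough sketch of just the padding step; the helper name is made up here, and the real code obtains the value via `get_forward_context()` rather than a parameter.

```python
# Illustrative padding step for the DP > 1 path (not the vllm-ascend code).
import torch
import torch.nn.functional as F


def pad_for_dp(hidden_states: torch.Tensor,
               router_logits: torch.Tensor,
               max_tokens_across_dp: int):
    """Pad the token dimension of both tensors up to the DP-wide maximum."""
    pad = max_tokens_across_dp - hidden_states.shape[0]
    if pad > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
        router_logits = F.pad(router_logits, (0, 0, 0, pad))
    return hidden_states, router_logits


# e.g. this rank has 3 tokens while the largest DP rank has 7:
h, r = pad_for_dp(torch.randn(3, 16), torch.randn(3, 4), max_tokens_across_dp=7)
print(h.shape, r.shape)  # torch.Size([7, 16]) torch.Size([7, 4])
```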