 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()

+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
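For context, `all_to_all` (whose tail appears in the next hunk) chunks `input_` into one piece per rank along `scatter_dim`, exchanges the chunks with `dist.all_to_all` over `self.device_group`, and concatenates the received chunks along `gather_dim`. A minimal single-process sketch of those semantics, emulating the ranks with a plain Python list instead of a process group (the helper name and the `gather_dim=-1` default are assumptions for illustration only):

```python
import torch

def emulate_all_to_all(inputs: list[torch.Tensor],
                       scatter_dim: int = 0,
                       gather_dim: int = -1) -> list[torch.Tensor]:
    # Single-process emulation of the scatter/gather pattern used by
    # NPUCommunicator.all_to_all; here the "ranks" are just list entries.
    world_size = len(inputs)
    # Each rank splits its input into world_size chunks along scatter_dim.
    chunked = [list(torch.chunk(t, world_size, dim=scatter_dim))
               for t in inputs]
    # Rank dst receives chunk dst from every rank and concatenates them
    # along gather_dim, mirroring the torch.cat(...).contiguous() above.
    return [
        torch.cat([chunked[src][dst] for src in range(world_size)],
                  dim=gather_dim).contiguous()
        for dst in range(world_size)
    ]

# Example: two "ranks", each holding a (4, 3) tensor.
outs = emulate_all_to_all([torch.randn(4, 3), torch.randn(4, 3)],
                          scatter_dim=0, gather_dim=-1)
print([o.shape for o in outs])  # [torch.Size([2, 6]), torch.Size([2, 6])]
```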
@@ -73,3 +80,17 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    # TODO: Add unit tests for dispatch and combine
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
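The new `dispatch`/`combine` pair simply forwards to the all2all manager: `dispatch` reshuffles tokens and their router logits to the ranks hosting the selected experts, and `combine` brings the expert outputs back. A hedged usage sketch of how an MoE forward pass might call these hooks; the `moe_forward` wrapper and the `experts` callable are hypothetical and not part of this change:

```python
import torch

def moe_forward(communicator, hidden_states: torch.Tensor,
                router_logits: torch.Tensor, experts) -> torch.Tensor:
    # Hypothetical wrapper: only dispatch/combine come from this PR.
    # Send each token (and its router logits) to the ranks that host
    # the experts it was routed to.
    hidden_states, router_logits = communicator.dispatch(
        hidden_states, router_logits)
    # Run the locally hosted experts on the tokens received by this rank.
    expert_out = experts(hidden_states, router_logits)
    # Route/reduce the expert outputs back to the original token layout.
    return communicator.combine(expert_out)
```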