 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.forward_context import get_forward_context


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,20 @@ def __init__(self,
         # init device according to rank
         self.device = torch.npu.current_device()

+        # Adapted from vllm/distributed/device_communicators/base_device_communicator.py
+        if self.use_all2all:
+            # compute some common properties
+            from vllm.distributed.parallel_state import (get_dp_group,
+                                                         get_tp_group)
+
+            # all2all lives in the ep group, which is merged from the dp and tp groups
+            self.dp_group = get_dp_group()
+            self.tp_group = get_tp_group()
+            # no self.ep_group since self.ep_group is still under construction
+            # when we create this object
+            self.dp_rank = self.dp_group.rank_in_group
+            self.dp_world_size = self.dp_group.world_size
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +88,43 @@ def all_to_all(self,
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
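
For reference, a minimal illustrative sketch of the indexing convention the new dispatch()/combine() path relies on, assuming cu_tokens_across_dp_cpu holds the cumulative token counts per DP rank (as it is used above). The concrete values and the standalone loop below are made up for illustration and are not part of the diff:

import torch

# Hypothetical example: DP world size 2, with 3 tokens on rank 0 and 5 on rank 1.
cu_tokens_across_dp_cpu = torch.tensor([3, 8])
dp_world_size = cu_tokens_across_dp_cpu.numel()

for dp_rank in range(dp_world_size):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    # naive_multicast() copies the local rank's rows into buffer[start:end] and
    # broadcasts every slice from its owning rank, so after dispatch() each rank
    # holds all 8 rows; combine() all-reduces the results across the DP group and
    # slices rows start:end back out for the local rank.
    print(f"dp_rank {dp_rank}: owns rows {start}:{end} of the multicast buffer")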