 import torch.distributed as dist

 import vllm.envs as envs
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -34,41 +34,60 @@ def __init__(self, cpu_group):
         super().__init__(cpu_group)

     def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
+                        cu_tokens_across_sp_cpu: torch.Tensor,
+                        is_sequence_parallel: bool) -> torch.Tensor:
         assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                              device=x.device,
                              dtype=x.dtype)

-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        rank = self.rank if is_sequence_parallel else self.dp_rank
+        world_size = (self.world_size
+                      if is_sequence_parallel else self.dp_world_size)
+
+        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
+        end = cu_tokens_across_sp_cpu[rank]
         buffer[start:end, :].copy_(x)
-        for idx in range(self.dp_world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            self.dp_group.broadcast(buffer[start:end, :], idx)
+        for idx in range(world_size):
+            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
+            end = cu_tokens_across_sp_cpu[idx]
+            get_ep_group().broadcast(buffer[start:end, :], idx)

         return buffer

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
-            [hidden_states, router_logits],
-            dim=0,
-            sizes=sizes,
-        )
-
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        dp_metadata = get_forward_context().dp_metadata
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
+
+        ep_rank = self.rank if is_sequence_parallel else self.dp_rank
+
+        dp_metadata = get_forward_context().dp_metadata
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
+        end = cu_tokens_across_sp_cpu[ep_rank]
+
+        all_hidden_states = get_ep_group().all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
         return hidden_states

     def destroy(self):
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+        hidden_states, router_logits = dist_group.all_gatherv(
             [hidden_states, router_logits],
             dim=0,
             sizes=sizes,
         )
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        hidden_states = dist_group.reduce_scatterv(hidden_states,
+                                                   dim=0,
+                                                   sizes=sizes)
         return hidden_states

     def destroy(self):
@@ -148,11 +178,17 @@ def get_handle(self, kwargs):
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError

     def destroy(self):
@@ -184,11 +220,17 @@ def __init__(self, cpu_group):
     def get_handle(self, kwargs):
         raise NotImplementedError

-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError

     def destroy(self):
@@ -395,4 +437,4 @@ def cleanup(self):
         self.workspace_tensor = None
         self.prepare_workspace_tensor = None
         self.mapping = None
-        self.initialized = False
+        self.initialized = False
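
The slicing logic shared by naive_multicast() and the new combine() above keys every rank's chunk off cumulative token counts. The following is a minimal, self-contained sketch of that bookkeeping in plain torch with no distributed initialization; the per-rank token counts are made-up illustration values, and cu_tokens_across_sp_cpu only mirrors the variable name used in the diff.

import torch

# Per-rank token counts (illustration values): ranks 0..2 own 3, 5 and 2 tokens.
num_tokens_per_rank = torch.tensor([3, 5, 2])
cu_tokens_across_sp_cpu = torch.cumsum(num_tokens_per_rank, dim=0)  # [3, 8, 10]

hidden_size = 4
buffer = torch.empty((int(cu_tokens_across_sp_cpu[-1]), hidden_size))

for rank in range(len(num_tokens_per_rank)):
    # Each rank's rows occupy buffer[start:end]; in the real code this slice is
    # filled by a broadcast from `rank` (dispatch) or cut out of the result of
    # an all-reduce (combine). Here we just tag the rows with the rank id.
    start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
    end = int(cu_tokens_across_sp_cpu[rank])
    buffer[start:end, :] = float(rank)

assert buffer.shape[0] == int(num_tokens_per_rank.sum())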