 from deepspeed.utils import groups


+def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
+    """
+    This function generates the parameters required for `permute` and `reshape` operations,
+    which are used to process data before and after `all2all` communication.
+    """
+    if batch_dim_idx == 0:
+        if scatter_idx < 2:
+            bs, global_seq_len, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
+            pre_all2all_permute_idx = (1, 0, 2, 3, 4)
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
+        else:
+            bs, local_seq_len, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+
+            post_all2all_permute_idx = (1, 0, 2, 3, 4)
+            post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
+    else:
+        if scatter_idx < 2:
+            global_seq_len, bs, num_local_head, head_dim = input.shape
+            pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
+            pre_all2all_permute_idx = None
+
+            post_all2all_permute_idx = (1, 2, 0, 3, 4)
+            post_all2all_res_shape = [global_seq_len // seq_world_size, bs, seq_world_size * num_local_head, head_dim]
+        else:
+            local_seq_len, bs, num_total_head, head_dim = input.shape
+            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
+            pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
+            pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+            post_all2all_permute_idx = None
+            post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
+
+    return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
+
+
+def post_all2all(permute_idx, res_shape):
+    """
+    Post-processing function for `all2all` communication.
+    """
+
+    def post_func(input):
+        if permute_idx is not None:
+            input = input.permute(permute_idx).contiguous()
+        output = input.reshape(res_shape).contiguous()
+
+        return output
+
+    return post_func
+
+
+def pre_all2all_fun(permute_idx, inp_shape, input):
+    """
+    Pre-processing function for `all2all` communication.
+    """
+    input_t = input.reshape(inp_shape).contiguous()
+    if permute_idx is not None:
+        input_t = input_t.permute(permute_idx).contiguous()
+    return input_t
+
+
 def _rotate_half(x):
     """
     change sign so the last dimension becomes [-odd, +even]
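The three helpers above factor the per-layout reshape/permute bookkeeping out of `single_all_to_all`. As a rough single-process sanity check (the sizes below are assumed toy values, and the `dist.all_to_all_single` call is skipped, so only the layout math is exercised), they can be driven like this:

import torch

# Assumed toy configuration: batch-first layout, scatter heads / gather sequence.
seq_world_size = 4
bs, local_seq_len, num_total_head, head_dim = 2, 8, 16, 64
x = torch.randn(bs, local_seq_len, num_total_head, head_dim)

perm, inp_shape, post_perm, res_shape = _generate_layout_params(
    scatter_idx=2, batch_dim_idx=0, seq_world_size=seq_world_size, input=x)

# The pre-all2all layout puts the scatter dimension first:
# [seq_world_size, bs, local_seq_len, num_total_head // seq_world_size, head_dim].
x_t = pre_all2all_fun(perm, inp_shape, x)
assert list(x_t.shape) == [seq_world_size, bs, local_seq_len,
                           num_total_head // seq_world_size, head_dim]

# Reusing x_t in place of the real all2all output, the post step recovers
# the gathered sequence with this rank's share of the heads.
res = post_all2all(post_perm, res_shape)(x_t)
assert list(res.shape) == [bs, seq_world_size * local_seq_len,
                           num_total_head // seq_world_size, head_dim]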
@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
     return res


-def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
-
-    def post_func(input):
-        if batch_dim_idx == 0:
-            # b, s, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.permute(1, 0, 2, 3, 4).contiguous()
-                output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
-                                        head_dim).contiguous()
-        else:
-            # s, b, n, h
-            if scatter_idx < 2:
-                output = input.permute(1, 2, 0, 3, 4).contiguous()
-                output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
-                                        head_dim).contiguous()
-            else:
-                output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
-        return output
-
-    return post_func
-
-
 def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
     seq_world_size = dist.get_world_size(group)
     inp_shape = list(input.shape)
@@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
         assert async_op == False, "uneven head sp does not support async op"
         return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

-    if batch_dim_idx == 0:
-        # b, s, n, h
-        if scatter_idx < 2:
-            bs, global_seq_len, num_local_head, head_dim = input.shape
-            input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
-        else:
-            bs, local_seq_len, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
-    else:
-        # s, b, n, h
-        if scatter_idx < 2:
-            global_seq_len, bs, num_local_head, head_dim = input.shape
-            input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
-                                     head_dim]).contiguous()
-        else:
-            local_seq_len, bs, num_total_head, head_dim = input.shape
-            assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
-            input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
-                                     head_dim]).contiguous()
-            input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
+        scatter_idx, batch_dim_idx, seq_world_size, input)

-    if scatter_idx < 2:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
-                                        head_dim)
-    else:
-        post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
-                                        head_dim)
+    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

+    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
     output = torch.empty_like(input_t)
     work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
             handle[type + '_work'] = work
             handle[type + '_grad'] = output
             handle[type + '_post_all2all_func'] = post_all2all_fun
-            return output
+            return output.view(post_all2all_res_shape)

     res = post_all2all_fun(output)
     return res
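For concreteness, with assumed numbers: at seq_world_size = 4, batch_dim_idx = 0 and scatter_idx = 2, an input of shape [bs=2, local_seq_len=1024, num_total_head=32, head_dim=128] is rearranged by pre_all2all_fun into [4, 2, 1024, 8, 128]; after dist.all_to_all_single, post_all2all_fun permutes and reshapes the result to [2, 4096, 8, 128], i.e. the full sequence with num_total_head // seq_world_size heads per rank. On the async path the raw output is exposed immediately as output.view(post_all2all_res_shape), while the post_all2all_fun stored in handle is left for the caller to apply once the communication has completed.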
@@ -271,7 +283,6 @@ def forward(ctx: Any,
             assert ctx.stream != None
             res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
             get_accelerator().current_stream().wait_stream(ctx.stream)
-            del ctx.stream.activation_buffer_list
             # The computation of d o_weight can overlap with the communication of d o_input

         elif not is_fwd and type in ('q', 'k'):