Skip to content

Commit 5c6eaa4

Browse files
inkcherryloadamstohtana
authored andcommitted
Add the missing view operations from sequence parallel(async). (deepspeedai#6750)
FYI @loadams a view operation was missing in some updates compared to the original version https://github.com/microsoft/DeepSpeed/blob/17ed7c77c58611a923a6c8d2a3d21d359cd046e8/deepspeed/sequence/layer.py#L56 add missing view operation. The shape required for the view cannot be easily obtained in the current function, so refactor layout params code. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Signed-off-by: gyou2021 <ganmei.you@intel.com>
1 parent 90758a9 commit 5c6eaa4

File tree

1 file changed

+70
-59
lines changed

1 file changed

+70
-59
lines changed

deepspeed/sequence/layer.py

Lines changed: 70 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,71 @@
1616
from deepspeed.utils import groups
1717

1818

19+
def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
20+
"""
21+
This function generates the parameters required for `permute` and `reshape` operations,
22+
which are used to process data before and after `all2all` communication.
23+
"""
24+
if batch_dim_idx == 0:
25+
if scatter_idx < 2:
26+
bs, global_seq_len, num_local_head, head_dim = input.shape
27+
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
28+
pre_all2all_permute_idx = (1, 0, 2, 3, 4)
29+
30+
post_all2all_permute_idx = (1, 2, 0, 3, 4)
31+
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
32+
else:
33+
bs, local_seq_len, num_total_head, head_dim = input.shape
34+
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
35+
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
36+
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
37+
38+
post_all2all_permute_idx = (1, 0, 2, 3, 4)
39+
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
40+
else:
41+
if scatter_idx < 2:
42+
global_seq_len, bs, num_local_head, head_dim = input.shape
43+
pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
44+
pre_all2all_permute_idx = None
45+
46+
post_all2all_permute_idx = (1, 2, 0, 3, 4)
47+
post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
48+
else:
49+
local_seq_len, bs, num_total_head, head_dim = input.shape
50+
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
51+
pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
52+
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
53+
post_all2all_permute_idx = None
54+
post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]
55+
56+
return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
57+
58+
59+
def post_all2all(permute_idx, res_shape):
60+
"""
61+
Post-processing function for `all2all` communication.
62+
"""
63+
64+
def post_func(input):
65+
if permute_idx is not None:
66+
input = input.permute(permute_idx).contiguous()
67+
output = input.reshape(res_shape).contiguous()
68+
69+
return output
70+
71+
return post_func
72+
73+
74+
def pre_all2all_fun(permute_idx, inp_shape, input):
75+
"""
76+
Pre-processing function for `all2all` communication.
77+
"""
78+
input_t = input.reshape(inp_shape).contiguous()
79+
if permute_idx is not None:
80+
input_t = input_t.permute(permute_idx).contiguous()
81+
return input_t
82+
83+
1984
def _rotate_half(x):
2085
"""
2186
change sign so the last dimension becomes [-odd, +even]
@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
43108
return res
44109

45110

46-
def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):
47-
48-
def post_func(input):
49-
if batch_dim_idx == 0:
50-
# b, s, n, h
51-
if scatter_idx < 2:
52-
output = input.permute(1, 2, 0, 3, 4).contiguous()
53-
output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
54-
head_dim).contiguous()
55-
else:
56-
output = input.permute(1, 0, 2, 3, 4).contiguous()
57-
output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
58-
head_dim).contiguous()
59-
else:
60-
# s, b, n, h
61-
if scatter_idx < 2:
62-
output = input.permute(1, 2, 0, 3, 4).contiguous()
63-
output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
64-
head_dim).contiguous()
65-
else:
66-
output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
67-
return output
68-
69-
return post_func
70-
71-
72111
def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
73112
seq_world_size = dist.get_world_size(group)
74113
inp_shape = list(input.shape)
@@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
195234
assert async_op == False, "uneven head sp does not support async op"
196235
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
197236

198-
if batch_dim_idx == 0:
199-
# b, s, n, h
200-
if scatter_idx < 2:
201-
bs, global_seq_len, num_local_head, head_dim = input.shape
202-
input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
203-
head_dim]).contiguous()
204-
input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
205-
else:
206-
bs, local_seq_len, num_total_head, head_dim = input.shape
207-
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
208-
input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
209-
head_dim]).contiguous()
210-
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
211-
else:
212-
# s, b, n, h
213-
if scatter_idx < 2:
214-
global_seq_len, bs, num_local_head, head_dim = input.shape
215-
input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
216-
head_dim]).contiguous()
217-
else:
218-
local_seq_len, bs, num_total_head, head_dim = input.shape
219-
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
220-
input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
221-
head_dim]).contiguous()
222-
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
237+
pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
238+
scatter_idx, batch_dim_idx, seq_world_size, input)
223239

224-
if scatter_idx < 2:
225-
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
226-
head_dim)
227-
else:
228-
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
229-
head_dim)
240+
input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)
230241

242+
post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
231243
output = torch.empty_like(input_t)
232244
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
233245

@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
236248
handle[type + '_work'] = work
237249
handle[type + '_grad'] = output
238250
handle[type + '_post_all2all_func'] = post_all2all_fun
239-
return output
251+
return output.view(post_all2all_res_shape)
240252

241253
res = post_all2all_fun(output)
242254
return res
@@ -271,7 +283,6 @@ def forward(ctx: Any,
271283
assert ctx.stream != None
272284
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
273285
get_accelerator().current_stream().wait_stream(ctx.stream)
274-
del ctx.stream.activation_buffer_list
275286
# The computation of d o_weight can overlap with the communication of d o_input
276287

277288
elif not is_fwd and type in ('q', 'k'):

0 commit comments

Comments
 (0)