From cb15ffa15304dfe2373f6bcc5353b397ecb93ed3 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 21 May 2024 07:37:29 +0000 Subject: [PATCH 01/16] fix ds-sp grad scale for zero0 --- deepspeed/runtime/engine.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 13f335cae6d5..95eb5d72e316 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2411,20 +2411,24 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer): else: dp_group = groups._get_sequence_data_parallel_group() + dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size) for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): if sparse_bucket_tuple: bucket_type, sparse_bucket = sparse_bucket_tuple - self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group) + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size) for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): if dense_bucket_tuple: bucket_type, dense_bucket = dense_bucket_tuple - self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) + self.allreduce_no_retain(dense_bucket, + dp_group=dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): # to maintain the gradients value unaffected by ep_size setting, # utilize dp_world_size for allreduce average - dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) + dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) / float(self.sequence_parallel_size) for ep_name, expert_grads_group in expert_grads.items(): ep_dp_group = groups._get_expert_data_parallel_group(ep_name) split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( @@ -2491,9 +2495,9 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size))) + values.mul_(self.gradient_predivide_factor() / (dp_world_size)) else: - values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size))) + values.mul_(1. 
/ (dp_world_size)) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) From a037a53fb9687c1d39d6a6eaee8d0b5850340794 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 18 Jun 2024 13:38:41 +0000 Subject: [PATCH 02/16] enable o compute async --- deepspeed/sequence/layer.py | 80 ++++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index e1dbff87f4ec..6f0f8350b0eb 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -10,9 +10,16 @@ from torch.nn import Module import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator -def single_all_to_all(input, scatter_idx, gather_idx, group): +def wait_stream(stream): + get_accelerator().wait_stream(stream) + +def print0(msg): + if dist.get_rank()==0: + print(msg) +def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size @@ -29,8 +36,8 @@ def single_all_to_all(input, scatter_idx, gather_idx, group): ).transpose(0, 1).contiguous() output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) - + work= dist.all_to_all_single(output, input_t, group=group, async_op=async_op) + # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_idx < 2: output = output.transpose(0, 1).contiguous() @@ -38,23 +45,45 @@ def single_all_to_all(input, scatter_idx, gather_idx, group): return output.reshape( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() + inp_shape[gather_idx + 1:]).contiguous() ,work class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx - - return single_all_to_all(input, scatter_idx, gather_idx, group) + ctx.stream=stream + ctx.bwd_async=bwd_async + # if stream != None: + # with get_accelerator().stream(stream): + # res, work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) + # else: + # res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) + # def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): + + if fwd_async: + get_accelerator().current_stream().wait_stream(ctx.stream) + + return res + @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + # print0("all2all o before") + # import pydevd + # pydevd.settrace() + + #def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: + q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False), None, None,None,None,None) + # print0("all2all o after") + + return q class DistributedAttention(torch.nn.Module): @@ -73,6 +102,7 @@ 
def __init__( sequence_process_group: dist.ProcessGroup, scatter_idx: int = 2, gather_idx: int = 0, + sp_stream=None ) -> None: super(DistributedAttention, self).__init__() @@ -80,7 +110,13 @@ def __init__( self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx + # self.q_stream=get_accelerator().Stream() + # self.k_stream=get_accelerator().Stream() + # self.v_stream=get_accelerator().Stream() + self.sp_stream=sp_stream + b=0 + # query = slef.linearq(hidden) def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor: """ forward @@ -96,14 +132,30 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! #in shape : e.g., [s/p:h:] - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + + #step1 get q ,k ,v outside out this function + # def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: - #out shape : e.g., [s:h/p:] - context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) #[1,512,32,32] + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) #[1,512,32,32] + value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + #out shape : e.g., [s:h/p:] + # print(query_layer) #2,8, 2,4 sp=2 2gpus + # # + # print(key_layer) + # print(value_layer) #seq_len 16 , sp 2 , head_dim = 4, num_heads=4, hidding=16 + + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) #[8,512,4,32] + bwd_o_async=False + if self.sp_stream is not None: + bwd_o_async=True + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,False,bwd_o_async) + + # dO=wdY #out e.g., [s/p::h] return output + + #o= self.dense(output) \ No newline at end of file From 42d12849dfbc7ed021cb799de2bf2e44de3ff1f4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Jun 2024 07:14:54 +0000 Subject: [PATCH 03/16] enable qk bwd async all2all --- deepspeed/sequence/layer.py | 167 +++++++++++++++++++++++++++++++++--- 1 file changed, 153 insertions(+), 14 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 6f0f8350b0eb..82868f70e91b 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -19,7 +19,7 @@ def wait_stream(stream): def print0(msg): if dist.get_rank()==0: print(msg) -def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): +def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,handle=None,type=None): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size @@ -37,21 +37,39 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): output = torch.empty_like(input_t) work= dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - + # if scattering 
the seq-dim, transpose the heads back to the original dimension if scatter_idx < 2: output = output.transpose(0, 1).contiguous() - - return output.reshape( + if async_op: + + # work.wait() + # if(dist.get_rank()==0): + # k=0 + c=output.reshape( + inp_shape[: gather_idx] + \ + [inp_shape[gather_idx] * seq_world_size,] + \ + inp_shape[gather_idx + 1:]).contiguous() + # return c,work + # qq = torch.empty_like(c,device='meta') + handle[type+'_grad']=output + # if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + # b=0 + return c, work + #!! need to delete + c= output.reshape( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() ,work + inp_shape[gather_idx + 1:]).contiguous() + return c,work class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx @@ -59,16 +77,30 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, ctx.gather_idx = gather_idx ctx.stream=stream ctx.bwd_async=bwd_async + ctx.handle=handle + ctx.type=type # if stream != None: # with get_accelerator().stream(stream): # res, work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) # else: # res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) # def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): + # print0(f"fwd_async:{fwd_async},type:{type},handle:{handle}") + if fwd_async and stream!=None: + # print0('11') + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) - if fwd_async: get_accelerator().current_stream().wait_stream(ctx.stream) + elif fwd_async and handle!=None: + # print0('22') + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async,handle,type) + # import pydevd + # pydevd.settrace() + handle[type]=work + b=0 + else: + # print0('33') + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) return res @@ -80,7 +112,7 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: # pydevd.settrace() #def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: - q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False), None, None,None,None,None) + q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False,ctx.handle,ctx.type), None, None,None,None,None,None,None) # print0("all2all o after") return q @@ -102,7 +134,9 @@ def __init__( sequence_process_group: dist.ProcessGroup, scatter_idx: int = 2, gather_idx: int = 0, - sp_stream=None + sp_stream=None, + q_linear=None, + k_linear=None ) -> None: super(DistributedAttention, self).__init__() @@ -114,7 +148,27 @@ def __init__( # self.k_stream=get_accelerator().Stream() # self.v_stream=get_accelerator().Stream() self.sp_stream=sp_stream - b=0 + self.bwd_all2all_handels={} + self.bwd_all2all_handels['dq']=None + self.bwd_all2all_handels['dq_grad']=None + 
self.bwd_all2all_handels['dk']=None + self.bwd_all2all_handels['dk_grad']=None + + self.q_linear=q_linear + self.k_linear=k_linear + self.hook_register=False + # def q_hook(module, grad_input): + + # grad_input= grad_input.reshape( + # inp_shape[: scatter_idx] + \ + # [inp_shape[scatter_idx] * seq_world_size,] + \ + # inp_shape[scatter_idx + 1:]).contiguous() + + + # q_linear.register_full_backward_pre_hook(q_hook) + # k_linear.register_full_backward_pre_hook(k_hook) + + # query = slef.linearq(hidden) def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor: @@ -134,13 +188,98 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens #in shape : e.g., [s/p:h:] + + #step1 get q ,k ,v outside out this function # def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: + + + def q_hook(*notneeded): - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) #[1,512,32,32] - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) #[1,512,32,32] - value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] + #4096 1 2046 + # print("hookq") + # grad_input.contiguous() + # import pydevd + # pydevd.settrace() + self.bwd_all2all_handels['dq'].wait() + tmp=self.bwd_all2all_handels['dq_grad'] + notneeded=list(notneeded) + notneeded[0]=list(notneeded[0]) + notneeded[0][0]=tmp.reshape(1,4096,16,128).contiguous() + if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + b=0 + notneeded[0]=tuple(notneeded[0]) + notneeded=tuple(notneeded) + + # for_check=tmp.reshape(4096,1,16,128).contiguous() + # assert torch.equal(for_check,notneeded[0][0]) + + # print0("pass q") + if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + b=0 + # notneeded[0]=tuple(notneeded[0]) + # notneeded=tuple(notneeded) + # notneeded[0][0]=notneeded[0][0].reshape( + # inp_shape[: gather_idx] + \ + # [inp_shape[gather_idx] * seq_world_size,] + \ + # inp_shape[gather_idx + 1:]).contiguous() + # grad_input= grad_input.reshape( + # inp_shape[: scatter_idx] + \ + # [inp_shape[scatter_idx] * seq_world_size,] + \ + # inp_shape[scatter_idx + 1:]).contiguous() + # def k_hook(module, grad_input): + def k_hook(*notneeded): + # self.bwd_all2all_handels['dk'].wait() + tmp=self.bwd_all2all_handels['dk_grad'] + notneeded=list(notneeded) + notneeded[0]=list(notneeded[0]) + notneeded[0][0]=tmp.reshape(1,4096,16,128).contiguous() + if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + b=0 + notneeded[0]=tuple(notneeded[0]) + notneeded=tuple(notneeded) + + + # for_check=tmp.reshape(4096,1,16,128).contiguous() + # assert torch.equal(for_check,notneeded[0][0]) + # print0("pass k") + # print("hookk") + # grad_input.contiguous() + b=0 + + async_bwd_comm_q=False + async_bwd_comm_k=False + # if self.q_linear!=None: + # async_bwd_comm_q=True + # if self.k_linear!=None: + # async_bwd_comm_k=True + # if self.hook_register==False: + # if True: + if True: + async_bwd_comm_q=True + async_bwd_comm_k=True + #eval interval + fn_q = query.grad_fn.next_functions[0][0] + fn_q.register_prehook(q_hook) + fn_k = key.grad_fn.next_functions[0][0] + fn_k.register_prehook(k_hook) + + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'dq') #[1,512,32,32] + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, 
self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'dk') #[1,512,32,32] + + + value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] + ##all2all ayns to v_dense_bwd wait + + + #all2all ayns to k_dense_bwd wait #out shape : e.g., [s:h/p:] # print(query_layer) #2,8, 2,4 sp=2 2gpus # # From 6919af43f09c7791ae032ca3d35007bd2ca5601c Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 06:09:54 +0000 Subject: [PATCH 04/16] fwd optimi --- deepspeed/sequence/layer.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 82868f70e91b..021682ea3bf8 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -201,7 +201,10 @@ def q_hook(*notneeded): # grad_input.contiguous() # import pydevd # pydevd.settrace() + # torch.cuda.default_stream().wait_stream(self.sp_stream) self.bwd_all2all_handels['dq'].wait() + self.sp_stream.wait_stream(torch.cuda.default_stream()) + tmp=self.bwd_all2all_handels['dq_grad'] notneeded=list(notneeded) notneeded[0]=list(notneeded[0]) @@ -217,10 +220,10 @@ def q_hook(*notneeded): # assert torch.equal(for_check,notneeded[0][0]) # print0("pass q") - if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - b=0 + # if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + # b=0 # notneeded[0]=tuple(notneeded[0]) # notneeded=tuple(notneeded) # notneeded[0][0]=notneeded[0][0].reshape( @@ -233,7 +236,11 @@ def q_hook(*notneeded): # inp_shape[scatter_idx + 1:]).contiguous() # def k_hook(module, grad_input): def k_hook(*notneeded): - # self.bwd_all2all_handels['dk'].wait() + # torch.cuda.default_stream().wait_stream(self.sp_stream) + self.bwd_all2all_handels['dk'].wait() + self.sp_stream.wait_stream(torch.cuda.default_stream()) + + tmp=self.bwd_all2all_handels['dk_grad'] notneeded=list(notneeded) notneeded[0]=list(notneeded[0]) @@ -270,11 +277,13 @@ def k_hook(*notneeded): fn_q.register_prehook(q_hook) fn_k = key.grad_fn.next_functions[0][0] fn_k.register_prehook(k_hook) - + + torch.cuda.current_stream().wait_event(query.done_event) query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'dq') #[1,512,32,32] + torch.cuda.current_stream().wait_event(key.done_event) key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'dk') #[1,512,32,32] - - + # torch.cuda.current_stream().wait_event(value.done_event) + torch.cuda.current_stream().wait_stream(self.sp_stream) value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] ##all2all ayns to v_dense_bwd wait From 39596ac41c885b25212e051cfd73b63d67492b3a Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 08:11:46 +0000 Subject: [PATCH 05/16] fix1 remove linear arg, remove note --- deepspeed/sequence/layer.py | 91 +++++-------------------------------- 1 file changed, 11 insertions(+), 80 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 021682ea3bf8..8303dc484cee 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -43,20 +43,13 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand output = output.transpose(0, 1).contiguous() if async_op: - # work.wait() - # if(dist.get_rank()==0): - # k=0 c=output.reshape( inp_shape[: gather_idx] 
+ \ [inp_shape[gather_idx] * seq_world_size,] + \ inp_shape[gather_idx + 1:]).contiguous() - # return c,work - # qq = torch.empty_like(c,device='meta') + handle[type+'_grad']=output - # if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - # b=0 + return c, work #!! need to delete c= output.reshape( @@ -79,27 +72,17 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, ctx.bwd_async=bwd_async ctx.handle=handle ctx.type=type - # if stream != None: - # with get_accelerator().stream(stream): - # res, work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) - # else: - # res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async) - # def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False): - # print0(f"fwd_async:{fwd_async},type:{type},handle:{handle}") + if fwd_async and stream!=None: - # print0('11') + # print0('') res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) get_accelerator().current_stream().wait_stream(ctx.stream) elif fwd_async and handle!=None: - # print0('22') res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async,handle,type) - # import pydevd - # pydevd.settrace() + handle[type]=work - b=0 else: - # print0('33') res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) return res @@ -135,8 +118,6 @@ def __init__( scatter_idx: int = 2, gather_idx: int = 0, sp_stream=None, - q_linear=None, - k_linear=None ) -> None: super(DistributedAttention, self).__init__() @@ -144,9 +125,8 @@ def __init__( self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx - # self.q_stream=get_accelerator().Stream() - # self.k_stream=get_accelerator().Stream() - # self.v_stream=get_accelerator().Stream() + + self.sp_stream=sp_stream self.bwd_all2all_handels={} self.bwd_all2all_handels['dq']=None @@ -154,20 +134,9 @@ def __init__( self.bwd_all2all_handels['dk']=None self.bwd_all2all_handels['dk_grad']=None - self.q_linear=q_linear - self.k_linear=k_linear + self.hook_register=False - # def q_hook(module, grad_input): - - # grad_input= grad_input.reshape( - # inp_shape[: scatter_idx] + \ - # [inp_shape[scatter_idx] * seq_world_size,] + \ - # inp_shape[scatter_idx + 1:]).contiguous() - - - # q_linear.register_full_backward_pre_hook(q_hook) - # k_linear.register_full_backward_pre_hook(k_hook) - + # query = slef.linearq(hidden) @@ -196,12 +165,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens def q_hook(*notneeded): - #4096 1 2046 - # print("hookq") - # grad_input.contiguous() - # import pydevd - # pydevd.settrace() - # torch.cuda.default_stream().wait_stream(self.sp_stream) self.bwd_all2all_handels['dq'].wait() self.sp_stream.wait_stream(torch.cuda.default_stream()) @@ -216,25 +179,7 @@ def q_hook(*notneeded): notneeded[0]=tuple(notneeded[0]) notneeded=tuple(notneeded) - # for_check=tmp.reshape(4096,1,16,128).contiguous() - # assert torch.equal(for_check,notneeded[0][0]) - - # print0("pass q") - # if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - # b=0 - # notneeded[0]=tuple(notneeded[0]) - # notneeded=tuple(notneeded) - # notneeded[0][0]=notneeded[0][0].reshape( - # inp_shape[: gather_idx] + \ - # [inp_shape[gather_idx] * seq_world_size,] + \ - # inp_shape[gather_idx + 1:]).contiguous() - # grad_input= grad_input.reshape( - # inp_shape[: scatter_idx] + \ - # [inp_shape[scatter_idx] * seq_world_size,] + \ - # inp_shape[scatter_idx + 1:]).contiguous() - # def 
k_hook(module, grad_input): + def k_hook(*notneeded): # torch.cuda.default_stream().wait_stream(self.sp_stream) self.bwd_all2all_handels['dk'].wait() @@ -253,22 +198,11 @@ def k_hook(*notneeded): notneeded=tuple(notneeded) - # for_check=tmp.reshape(4096,1,16,128).contiguous() - # assert torch.equal(for_check,notneeded[0][0]) - # print0("pass k") - # print("hookk") - # grad_input.contiguous() - b=0 async_bwd_comm_q=False async_bwd_comm_k=False - # if self.q_linear!=None: - # async_bwd_comm_q=True - # if self.k_linear!=None: - # async_bwd_comm_k=True - # if self.hook_register==False: - # if True: + if True: async_bwd_comm_q=True async_bwd_comm_k=True @@ -282,10 +216,8 @@ def k_hook(*notneeded): query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'dq') #[1,512,32,32] torch.cuda.current_stream().wait_event(key.done_event) key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'dk') #[1,512,32,32] - # torch.cuda.current_stream().wait_event(value.done_event) torch.cuda.current_stream().wait_stream(self.sp_stream) value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] - ##all2all ayns to v_dense_bwd wait #all2all ayns to k_dense_bwd wait @@ -301,7 +233,6 @@ def k_hook(*notneeded): bwd_o_async=True output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,False,bwd_o_async) - # dO=wdY #out e.g., [s/p::h] return output From eb760c019b393afb5b2eb7615209ccfe9395b74d Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 12:12:22 +0000 Subject: [PATCH 06/16] async qkv fwd, optimi cpu ,make fwd call fast --- deepspeed/sequence/layer.py | 65 ++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 8303dc484cee..c73feaf395ca 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -47,8 +47,8 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ inp_shape[gather_idx + 1:]).contiguous() - - handle[type+'_grad']=output + if type=='dq' or type=='dk': + handle[type+'_grad']=output return c, work #!! 
need to delete @@ -62,7 +62,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None) -> Tensor: + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx @@ -73,14 +73,24 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, ctx.handle=handle ctx.type=type - if fwd_async and stream!=None: + # if fwd_async and stream!=None: + if not is_fwd and type=='o': + assert stream!=None # print0('') res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) get_accelerator().current_stream().wait_stream(ctx.stream) - elif fwd_async and handle!=None: - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,fwd_async,handle,type) + # elif fwd_async and handle!=None: + elif not is_fwd and (type=='q' or type=='k'): + assert fwd_async==True + type='d'+type + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type) + handle[type]=work + elif is_fwd and (type=='q' or type=='k'): + type='fwd_'+type + + res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type) handle[type]=work else: res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) @@ -95,7 +105,7 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: # pydevd.settrace() #def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: - q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False,ctx.handle,ctx.type), None, None,None,None,None,None,None) + q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False,ctx.handle,ctx.type,False), None, None,None,None,None,None,None,None) # print0("all2all o after") return q @@ -133,7 +143,7 @@ def __init__( self.bwd_all2all_handels['dq_grad']=None self.bwd_all2all_handels['dk']=None self.bwd_all2all_handels['dk_grad']=None - + self.dafult_stream=get_accelerator().default_stream() self.hook_register=False @@ -182,6 +192,12 @@ def q_hook(*notneeded): def k_hook(*notneeded): # torch.cuda.default_stream().wait_stream(self.sp_stream) + + # if(dist.get_rank()==0): + # import pydevd + # pydevd.settrace() + + self.bwd_all2all_handels['dk'].wait() self.sp_stream.wait_stream(torch.cuda.default_stream()) @@ -190,19 +206,25 @@ def k_hook(*notneeded): notneeded=list(notneeded) notneeded[0]=list(notneeded[0]) notneeded[0][0]=tmp.reshape(1,4096,16,128).contiguous() - if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - b=0 + notneeded[0]=tuple(notneeded[0]) notneeded=tuple(notneeded) - async_bwd_comm_q=False - async_bwd_comm_k=False + async_bwd_comm_q=True + async_bwd_comm_k=True + + + self.dafult_stream.wait_event(query.done_event) + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'q') #[1,512,32,32] + self.dafult_stream.wait_event(key.done_event) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, 
self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'k') #[1,512,32,32] + self.dafult_stream.wait_stream(self.sp_stream) + value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False,False, self.bwd_all2all_handels,'v') #[1,512,32,32] + if True: async_bwd_comm_q=True async_bwd_comm_k=True @@ -211,15 +233,12 @@ def k_hook(*notneeded): fn_q.register_prehook(q_hook) fn_k = key.grad_fn.next_functions[0][0] fn_k.register_prehook(k_hook) - - torch.cuda.current_stream().wait_event(query.done_event) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'dq') #[1,512,32,32] - torch.cuda.current_stream().wait_event(key.done_event) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'dk') #[1,512,32,32] - torch.cuda.current_stream().wait_stream(self.sp_stream) - value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) #[1,512,32,32] - - + #do dq qk k v + # def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None) -> Tensor: + + self.bwd_all2all_handels['fwd_q'].wait() + self.bwd_all2all_handels['fwd_k'].wait() + # self.bwd_all2all_handels['fwd_q'].wait() #all2all ayns to k_dense_bwd wait #out shape : e.g., [s:h/p:] # print(query_layer) #2,8, 2,4 sp=2 2gpus From c7d3374a46e0dfc543eef66a785d72f6cc13d250 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 13:27:08 +0000 Subject: [PATCH 07/16] update --- deepspeed/sequence/layer.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index c73feaf395ca..c10155a35203 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -42,14 +42,13 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand if scatter_idx < 2: output = output.transpose(0, 1).contiguous() if async_op: - - c=output.reshape( - inp_shape[: gather_idx] + \ + shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() + inp_shape[gather_idx + 1:]) + c=output.reshape(shape).contiguous() if type=='dq' or type=='dk': handle[type+'_grad']=output - + handle[type+'_grad_shape']=shape return c, work #!! 
need to delete c= output.reshape( @@ -181,32 +180,18 @@ def q_hook(*notneeded): tmp=self.bwd_all2all_handels['dq_grad'] notneeded=list(notneeded) notneeded[0]=list(notneeded[0]) - notneeded[0][0]=tmp.reshape(1,4096,16,128).contiguous() - if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - b=0 + notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dq_grad_shape']).contiguous() notneeded[0]=tuple(notneeded[0]) notneeded=tuple(notneeded) - def k_hook(*notneeded): - # torch.cuda.default_stream().wait_stream(self.sp_stream) - - # if(dist.get_rank()==0): - # import pydevd - # pydevd.settrace() - - + def k_hook(*notneeded): self.bwd_all2all_handels['dk'].wait() self.sp_stream.wait_stream(torch.cuda.default_stream()) - - tmp=self.bwd_all2all_handels['dk_grad'] notneeded=list(notneeded) notneeded[0]=list(notneeded[0]) - notneeded[0][0]=tmp.reshape(1,4096,16,128).contiguous() - + notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dk_grad_shape']).contiguous() notneeded[0]=tuple(notneeded[0]) notneeded=tuple(notneeded) From 70a6d0c9944d731366710662ec6440c257c155bc Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 14:10:13 +0000 Subject: [PATCH 08/16] refine code --- deepspeed/sequence/layer.py | 108 +++++++++++++----------------------- 1 file changed, 40 insertions(+), 68 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index c10155a35203..f00f9a051614 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -45,11 +45,11 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ inp_shape[gather_idx + 1:]) - c=output.reshape(shape).contiguous() + res=output.reshape(shape).contiguous() if type=='dq' or type=='dk': handle[type+'_grad']=output handle[type+'_grad_shape']=shape - return c, work + return res, work #!! 
need to delete c= output.reshape( inp_shape[: gather_idx] + \ @@ -61,7 +61,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor: + def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx @@ -72,16 +72,12 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, ctx.handle=handle ctx.type=type - # if fwd_async and stream!=None: if not is_fwd and type=='o': assert stream!=None - # print0('') res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) get_accelerator().current_stream().wait_stream(ctx.stream) - # elif fwd_async and handle!=None: elif not is_fwd and (type=='q' or type=='k'): - assert fwd_async==True type='d'+type res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type) @@ -99,15 +95,11 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - # print0("all2all o before") - # import pydevd - # pydevd.settrace() + + - #def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: - q= (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,ctx.bwd_async,False,ctx.handle,ctx.type,False), None, None,None,None,None,None,None,None) - # print0("all2all o after") - return q + return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,False,ctx.handle,ctx.type,False), None,None,None,None,None,None,None) class DistributedAttention(torch.nn.Module): @@ -137,18 +129,15 @@ def __init__( self.sp_stream=sp_stream - self.bwd_all2all_handels={} - self.bwd_all2all_handels['dq']=None - self.bwd_all2all_handels['dq_grad']=None - self.bwd_all2all_handels['dk']=None - self.bwd_all2all_handels['dk_grad']=None + self.overlap_handles={} + self.overlap_handles['dq']=None + self.overlap_handles['dq_grad']=None + self.overlap_handles['dk']=None + self.overlap_handles['dk_grad']=None self.dafult_stream=get_accelerator().default_stream() - - self.hook_register=False - # query = slef.linearq(hidden) def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor: """ forward @@ -169,31 +158,21 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens #step1 get q ,k ,v outside out this function - # def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False) -> Tensor: - - - def q_hook(*notneeded): - - self.bwd_all2all_handels['dq'].wait() - self.sp_stream.wait_stream(torch.cuda.default_stream()) - - tmp=self.bwd_all2all_handels['dq_grad'] - notneeded=list(notneeded) - notneeded[0]=list(notneeded[0]) - notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dq_grad_shape']).contiguous() - notneeded[0]=tuple(notneeded[0]) - notneeded=tuple(notneeded) + def bwd_hook(type): + + def pre_hook(*notneeded): + self.overlap_handles['d'+type].wait() + 
self.sp_stream.wait_stream(torch.cuda.default_stream()) + tmp=self.overlap_handles['d'+type+'_grad'] + notneeded=list(notneeded) + notneeded[0]=list(notneeded[0]) + notneeded[0][0]=tmp.reshape(self.overlap_handles['d'+type+'_grad_shape']).contiguous() + notneeded[0]=tuple(notneeded[0]) + notneeded=tuple(notneeded) + return pre_hook - def k_hook(*notneeded): - self.bwd_all2all_handels['dk'].wait() - self.sp_stream.wait_stream(torch.cuda.default_stream()) - tmp=self.bwd_all2all_handels['dk_grad'] - notneeded=list(notneeded) - notneeded[0]=list(notneeded[0]) - notneeded[0][0]=tmp.reshape(self.bwd_all2all_handels['dk_grad_shape']).contiguous() - notneeded[0]=tuple(notneeded[0]) - notneeded=tuple(notneeded) + @@ -204,41 +183,34 @@ def k_hook(*notneeded): self.dafult_stream.wait_event(query.done_event) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_q,self.bwd_all2all_handels,'q') #[1,512,32,32] + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,async_bwd_comm_q,self.overlap_handles,'q') #[1,512,32,32] self.dafult_stream.wait_event(key.done_event) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,False,async_bwd_comm_k, self.bwd_all2all_handels,'k') #[1,512,32,32] + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,async_bwd_comm_k, self.overlap_handles,'k') #[1,512,32,32] self.dafult_stream.wait_stream(self.sp_stream) - value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False,False, self.bwd_all2all_handels,'v') #[1,512,32,32] + value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False, self.overlap_handles,'v') #[1,512,32,32] + # hard code currently if True: - async_bwd_comm_q=True - async_bwd_comm_k=True - #eval interval - fn_q = query.grad_fn.next_functions[0][0] - fn_q.register_prehook(q_hook) - fn_k = key.grad_fn.next_functions[0][0] - fn_k.register_prehook(k_hook) - #do dq qk k v - # def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None, fwd_async=False,bwd_async=False, handle=None,type=None) -> Tensor: - - self.bwd_all2all_handels['fwd_q'].wait() - self.bwd_all2all_handels['fwd_k'].wait() - # self.bwd_all2all_handels['fwd_q'].wait() + grad_fn_q = query.grad_fn.next_functions[0][0] + grad_fn_q.register_prehook(bwd_hook(type='q')) + grad_fn_k = key.grad_fn.next_functions[0][0] + grad_fn_k.register_prehook(bwd_hook(type='k')) + + + + self.overlap_handles['fwd_q'].wait() + self.overlap_handles['fwd_k'].wait() + # self.overlap_handles['fwd_q'].wait() #all2all ayns to k_dense_bwd wait #out shape : e.g., [s:h/p:] - # print(query_layer) #2,8, 2,4 sp=2 2gpus - # # - # print(key_layer) - # print(value_layer) #seq_len 16 , sp 2 , head_dim = 4, num_heads=4, hidding=16 - + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) #[8,512,4,32] bwd_o_async=False if self.sp_stream is not None: bwd_o_async=True - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,False,bwd_o_async) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,bwd_o_async) #out e.g., [s/p::h] return output - #o= self.dense(output) \ No newline at end of file From 65afd895055d723133436a83b381f975d25dbf0b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 14:32:31 +0000 Subject: [PATCH 09/16] refine 
code --- deepspeed/sequence/layer.py | 164 ++++++++++++++---------------------- 1 file changed, 65 insertions(+), 99 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index f00f9a051614..83a60f6c8382 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -13,13 +13,7 @@ from deepspeed.accelerator import get_accelerator -def wait_stream(stream): - get_accelerator().wait_stream(stream) - -def print0(msg): - if dist.get_rank()==0: - print(msg) -def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,handle=None,type=None): +def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size @@ -36,70 +30,67 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False,hand ).transpose(0, 1).contiguous() output = torch.empty_like(input_t) - work= dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - + work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) + # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_idx < 2: output = output.transpose(0, 1).contiguous() - if async_op: - shape=( inp_shape[: gather_idx] + \ - [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]) - res=output.reshape(shape).contiguous() - if type=='dq' or type=='dk': - handle[type+'_grad']=output - handle[type+'_grad_shape']=shape - return res, work - #!! need to delete - c= output.reshape( - inp_shape[: gather_idx] + \ + res_shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() - return c,work + inp_shape[gather_idx + 1:]) + res = output.reshape(res_shape).contiguous() + if async_op: + if type in ('dq', 'dk'): + handle[type + '_grad'] = output + handle[type + '_grad_shape'] = res_shape + return res, work class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int, stream=None,bwd_async=False, handle=None,type=None,is_fwd=True) -> Tensor: - + def forward(ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + stream=None, + handle=None, + type=None, + is_fwd=True) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx - ctx.gather_idx = gather_idx - ctx.stream=stream - ctx.bwd_async=bwd_async - ctx.handle=handle - ctx.type=type - - if not is_fwd and type=='o': - assert stream!=None - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) + ctx.stream = stream + ctx.handle = handle + ctx.type = type + if not is_fwd and type == 'o': + assert stream != None + res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) - elif not is_fwd and (type=='q' or type=='k'): - type='d'+type - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type) - - handle[type]=work - elif is_fwd and (type=='q' or type=='k'): - type='fwd_'+type - - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,True,handle,type) - handle[type]=work + + elif not is_fwd and type in ('q', 'k'): + type = 'd' + type + res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + handle[type] = work + + elif is_fwd and type in 
('q', 'k'): + type = 'fwd_' + type + res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + handle[type] = work + else: - res , work=single_all_to_all(input, scatter_idx, gather_idx, group,False) + res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) - return res - + return res @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - - - - return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream,False,ctx.handle,ctx.type,False), None,None,None,None,None,None,None) + return (None, + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle, + ctx.type, False), None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -126,17 +117,10 @@ def __init__( self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx - - - self.sp_stream=sp_stream - self.overlap_handles={} - self.overlap_handles['dq']=None - self.overlap_handles['dq_grad']=None - self.overlap_handles['dk']=None - self.overlap_handles['dk_grad']=None - self.dafult_stream=get_accelerator().default_stream() - - + self.sp_stream = sp_stream + #TODO: add class to clear logic for overlap + self.overlap_handles = {} + self.dafult_stream = get_accelerator().default_stream() def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor: """ forward @@ -150,67 +134,49 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens Returns: * output (Tensor): context output """ + # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! 
#in shape : e.g., [s/p:h:] - - - #step1 get q ,k ,v outside out this function def bwd_hook(type): - - def pre_hook(*notneeded): - self.overlap_handles['d'+type].wait() - self.sp_stream.wait_stream(torch.cuda.default_stream()) - tmp=self.overlap_handles['d'+type+'_grad'] - notneeded=list(notneeded) - notneeded[0]=list(notneeded[0]) - notneeded[0][0]=tmp.reshape(self.overlap_handles['d'+type+'_grad_shape']).contiguous() - notneeded[0]=tuple(notneeded[0]) - notneeded=tuple(notneeded) - return pre_hook - - - - - - - async_bwd_comm_q=True - async_bwd_comm_k=True + def pre_hook_fun(grad): + self.overlap_handles['d' + type].wait() + self.sp_stream.wait_stream(torch.cuda.default_stream()) + all2all_output = self.overlap_handles['d' + type + '_grad'] + grad = list(grad) + grad[0] = all2all_output.reshape(self.overlap_handles['d' + type + '_grad_shape']).contiguous() + grad = tuple(grad) + return pre_hook_fun - self.dafult_stream.wait_event(query.done_event) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx,None,async_bwd_comm_q,self.overlap_handles,'q') #[1,512,32,32] + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, + self.overlap_handles, 'q') self.dafult_stream.wait_event(key.done_event) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx,None,async_bwd_comm_k, self.overlap_handles,'k') #[1,512,32,32] + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, + 'k') self.dafult_stream.wait_stream(self.sp_stream) - value_layer= _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx,None,False, self.overlap_handles,'v') #[1,512,32,32] - + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, + self.overlap_handles, 'v') + # hard code currently if True: grad_fn_q = query.grad_fn.next_functions[0][0] grad_fn_q.register_prehook(bwd_hook(type='q')) grad_fn_k = key.grad_fn.next_functions[0][0] grad_fn_k.register_prehook(bwd_hook(type='k')) - - self.overlap_handles['fwd_q'].wait() self.overlap_handles['fwd_k'].wait() - # self.overlap_handles['fwd_q'].wait() - #all2all ayns to k_dense_bwd wait + #out shape : e.g., [s:h/p:] - context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) #[8,512,4,32] - bwd_o_async=False - if self.sp_stream is not None: - bwd_o_async=True - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx,self.sp_stream,bwd_o_async) + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream) #out e.g., [s/p::h] return output - From 4b3518edbe7fe5addca03147d51fabca6799dca1 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 14:57:43 +0000 Subject: [PATCH 10/16] Revert "fix ds-sp grad scale for zero0" This reverts commit cb15ffa15304dfe2373f6bcc5353b397ecb93ed3. 
--- deepspeed/runtime/engine.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 95eb5d72e316..13f335cae6d5 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2411,24 +2411,20 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer): else: dp_group = groups._get_sequence_data_parallel_group() - dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size) for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): if sparse_bucket_tuple: bucket_type, sparse_bucket = sparse_bucket_tuple - self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size) + self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group) for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): if dense_bucket_tuple: bucket_type, dense_bucket = dense_bucket_tuple - self.allreduce_no_retain(dense_bucket, - dp_group=dp_group, - numel_per_bucket=elements_per_buffer, - dp_world_size=dp_world_size) + self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): # to maintain the gradients value unaffected by ep_size setting, # utilize dp_world_size for allreduce average - dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) / float(self.sequence_parallel_size) + dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) for ep_name, expert_grads_group in expert_grads.items(): ep_dp_group = groups._get_expert_data_parallel_group(ep_name) split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( @@ -2495,9 +2491,9 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): dp_world_size = dist.get_world_size(group=dp_group) if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / (dp_world_size)) + values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size))) else: - values.mul_(1. / (dp_world_size)) + values.mul_(1. 
/ (dp_world_size / float(self.sequence_parallel_size))) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) From 54b5ce3d817a6488e7669a84678ddaa8ed66c386 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 15:28:14 +0000 Subject: [PATCH 11/16] fix format --- deepspeed/sequence/layer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index c429e549e46b..b666d349a83b 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -36,7 +36,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han res_shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ inp_shape[gather_idx + 1:]) - transpose = True if scatter_idx<2 else False + transpose = True if scatter_idx < 2 else False if async_op: if type in ('dq', 'dk'): handle[type + '_grad'] = output @@ -75,17 +75,17 @@ def forward(ctx: Any, assert stream != None res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) - + elif not is_fwd and type in ('q', 'k'): type = 'd' + type res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) handle[type] = work - + elif is_fwd and type in ('q', 'k'): type = 'fwd_' + type res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) handle[type] = work - + else: res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) @@ -153,8 +153,8 @@ def pre_hook_fun(grad): self.sp_stream.wait_stream(torch.cuda.default_stream()) all2all_output = self.overlap_handles['d' + type + '_grad'] grad = list(grad) - if self.overlap_handles['transpose']==True: - all2all_output=all2all_output.transpose(0, 2).contiguous() + if self.overlap_handles['transpose'] == True: + all2all_output = all2all_output.transpose(0, 2).contiguous() grad[0] = all2all_output.reshape(self.overlap_handles['d' + type + '_grad_shape']).contiguous() grad = tuple(grad) @@ -182,7 +182,6 @@ def pre_hook_fun(grad): #out shape : e.g., [s:h/p:] - context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream) From c9f0c0adeb0e835174b7dcb2bdeff11ccf94d9b8 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Jun 2024 15:38:27 +0000 Subject: [PATCH 12/16] fix format --- deepspeed/sequence/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index b666d349a83b..6d7bd3646e14 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -150,7 +150,7 @@ def bwd_hook(type): def pre_hook_fun(grad): self.overlap_handles['d' + type].wait() - self.sp_stream.wait_stream(torch.cuda.default_stream()) + self.sp_stream.wait_stream(get_accelerator().default_stream()) all2all_output = self.overlap_handles['d' + type + '_grad'] grad = list(grad) if self.overlap_handles['transpose'] == True: From 0862aa373d3d8a6082323c606be1d89e51d986be Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 27 Jun 2024 03:29:48 +0000 Subject: [PATCH 13/16] refine code --- deepspeed/sequence/layer.py | 121 +++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 51 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 
6d7bd3646e14..6c5a88ea914b 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - import torch from typing import Any, Tuple @@ -12,7 +11,14 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator - +def post_all2all(transpose, res_shape): + def post_func(input): + if transpose: + input=input.transpose(0, 2).contiguous() + input=input.reshape(res_shape) + return input + return post_func + def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) @@ -32,25 +38,25 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han output = torch.empty_like(input_t) work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - res_shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ inp_shape[gather_idx + 1:]) transpose = True if scatter_idx < 2 else False + post_all2all_fun = post_all2all(transpose, res_shape) + if async_op: if type in ('dq', 'dk'): + handle[type+'_work']=work handle[type + '_grad'] = output - handle[type + '_grad_shape'] = res_shape - handle['transpose'] = transpose - # placeholder on the same device with the same shape. - res = output.reshape(res_shape) - return res, work - # if scattering the seq-dim, transpose the heads back to the original dimension - if transpose: - output = output.transpose(0, 2).contiguous() - res = output.reshape(res_shape).contiguous() - return res, work + handle[type+'_post_all2all_func'] = post_all2all_fun + return None + res=post_all2all_fun(output) + return res + + + + class _SeqAllToAll(torch.autograd.Function): @@ -70,24 +76,29 @@ def forward(ctx: Any, ctx.stream = stream ctx.handle = handle ctx.type = type - - if not is_fwd and type == 'o': - assert stream != None - res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) - get_accelerator().current_stream().wait_stream(ctx.stream) - - elif not is_fwd and type in ('q', 'k'): - type = 'd' + type - res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) - handle[type] = work - - elif is_fwd and type in ('q', 'k'): - type = 'fwd_' + type - res, work = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) - handle[type] = work - + if ctx.handle is None: + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + else: - res, work = single_all_to_all(input, scatter_idx, gather_idx, group, False) + # overlap communcation path + if not is_fwd and type == 'o': + # The computation of d o_weight can overlap with the communication of d o_input + assert stream != None + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + get_accelerator().current_stream().wait_stream(ctx.stream) + + elif not is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v + type = 'd' + type + res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + + elif is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv + type = 'fwd_' + type + res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) + + else: + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) return res @@ -123,11 +134,17 @@ def 
__init__( self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx + self.sp_overlap_comm = False + self.overlap_handles = None self.sp_stream = sp_stream - #TODO: add class to clear logic for overlap - self.overlap_handles = {} - self.dafult_stream = get_accelerator().default_stream() - + if sp_stream is not None: + self.overlap_handles = {} + self.sp_overlap_comm =True + self.dafult_stream = get_accelerator().default_stream() + + def layer_sync(self, layer): + if self.sp_overlap_comm and hasattr(layer, 'done_event'): + self.dafult_stream.wait_event(layer.done_event) def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: """ forward @@ -145,46 +162,48 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! #in shape : e.g., [s/p:h:] - #step1 get q ,k ,v outside out this function - def bwd_hook(type): + def bwd_hook(layer_type): def pre_hook_fun(grad): - self.overlap_handles['d' + type].wait() + type='d' + layer_type + self.overlap_handles[type +'_work'].wait() self.sp_stream.wait_stream(get_accelerator().default_stream()) - all2all_output = self.overlap_handles['d' + type + '_grad'] + all2all_output = self.overlap_handles[type + '_grad'] grad = list(grad) - if self.overlap_handles['transpose'] == True: - all2all_output = all2all_output.transpose(0, 2).contiguous() - grad[0] = all2all_output.reshape(self.overlap_handles['d' + type + '_grad_shape']).contiguous() + grad[0]=self.overlap_handles[type + '_post_all2all_func'](all2all_output) grad = tuple(grad) return pre_hook_fun - self.dafult_stream.wait_event(query.done_event) + self.layer_sync(query) query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, self.overlap_handles, 'q') - self.dafult_stream.wait_event(key.done_event) + self.layer_sync(key) key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, 'k') - self.dafult_stream.wait_stream(self.sp_stream) + if self.sp_overlap_comm: + self.dafult_stream.wait_stream(self.sp_stream) + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, self.overlap_handles, 'v') - # hard code currently - if True: + + if self.sp_overlap_comm: + # Register a hook to synchronize dq and dk after the all-to-all operation + # when the gradient data is used. Place this logic after the q, k, v + # all-to-all operation to improve interpreter speed by enabling an earlier + # call and launch of the forward all-to-all communication. 
grad_fn_q = query.grad_fn.next_functions[0][0] - grad_fn_q.register_prehook(bwd_hook(type='q')) + grad_fn_q.register_prehook(bwd_hook(layer_type='q')) grad_fn_k = key.grad_fn.next_functions[0][0] - grad_fn_k.register_prehook(bwd_hook(type='k')) + grad_fn_k.register_prehook(bwd_hook(layer_type='k')) - self.overlap_handles['fwd_q'].wait() - self.overlap_handles['fwd_k'].wait() #out shape : e.g., [s:h/p:] context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, self.overlap_handles,'o') #out e.g., [s/p::h] return output From 1c596dd6fa3845a2ca9c8034083891577d65cd11 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 5 Jul 2024 16:21:38 +0800 Subject: [PATCH 14/16] add register for v, ensuring they launch on a single thread. --- deepspeed/sequence/layer.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 6c5a88ea914b..9af1c76c70eb 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -45,7 +45,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han post_all2all_fun = post_all2all(transpose, res_shape) if async_op: - if type in ('dq', 'dk'): + if type in ('dq', 'dk','dv'): handle[type+'_work']=work handle[type + '_grad'] = output handle[type+'_post_all2all_func'] = post_all2all_fun @@ -55,9 +55,6 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han return res - - - class _SeqAllToAll(torch.autograd.Function): @staticmethod @@ -82,12 +79,13 @@ def forward(ctx: Any, else: # overlap communcation path if not is_fwd and type == 'o': - # The computation of d o_weight can overlap with the communication of d o_input - assert stream != None + assert ctx.stream != None res = single_all_to_all(input, scatter_idx, gather_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) + del ctx.stream.activation_buffer_list + # The computation of d o_weight can overlap with the communication of d o_input - elif not is_fwd and type in ('q', 'k'): + elif not is_fwd and type in ('q', 'k','v'): # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'd' + type res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) @@ -189,14 +187,17 @@ def pre_hook_fun(grad): if self.sp_overlap_comm: - # Register a hook to synchronize dq and dk after the all-to-all operation - # when the gradient data is used. Place this logic after the q, k, v - # all-to-all operation to improve interpreter speed by enabling an earlier + # Register a hook to synchronize dq and dk after the all-to-all + # operation when the gradient data is used. + # Place this logic after the q, k, v all-to-all operation to + # improve interpreter speed to # call and launch of the forward all-to-all communication. 
grad_fn_q = query.grad_fn.next_functions[0][0] grad_fn_q.register_prehook(bwd_hook(layer_type='q')) grad_fn_k = key.grad_fn.next_functions[0][0] grad_fn_k.register_prehook(bwd_hook(layer_type='k')) + grad_fn_k = value.grad_fn.next_functions[0][0] + grad_fn_k.register_prehook(bwd_hook(layer_type='v')) #out shape : e.g., [s:h/p:] From 2fbbd5eb80580ba9eaa83ccdf478df50182d458b Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 10 Jul 2024 05:27:57 +0000 Subject: [PATCH 15/16] remove v --- deepspeed/sequence/layer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index 9af1c76c70eb..cba7cf913729 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -45,11 +45,11 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han post_all2all_fun = post_all2all(transpose, res_shape) if async_op: - if type in ('dq', 'dk','dv'): + if type in ('dq', 'dk'): handle[type+'_work']=work handle[type + '_grad'] = output handle[type+'_post_all2all_func'] = post_all2all_fun - return None + return output.view(res_shape) res=post_all2all_fun(output) return res @@ -85,7 +85,7 @@ def forward(ctx: Any, del ctx.stream.activation_buffer_list # The computation of d o_weight can overlap with the communication of d o_input - elif not is_fwd and type in ('q', 'k','v'): + elif not is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'd' + type res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) @@ -165,7 +165,7 @@ def bwd_hook(layer_type): def pre_hook_fun(grad): type='d' + layer_type self.overlap_handles[type +'_work'].wait() - self.sp_stream.wait_stream(get_accelerator().default_stream()) + self.sp_stream.wait_stream(self.dafult_stream) all2all_output = self.overlap_handles[type + '_grad'] grad = list(grad) grad[0]=self.overlap_handles[type + '_post_all2all_func'](all2all_output) @@ -185,7 +185,6 @@ def pre_hook_fun(grad): value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, self.overlap_handles, 'v') - if self.sp_overlap_comm: # Register a hook to synchronize dq and dk after the all-to-all # operation when the gradient data is used. 
@@ -196,8 +195,6 @@ def pre_hook_fun(grad): grad_fn_q.register_prehook(bwd_hook(layer_type='q')) grad_fn_k = key.grad_fn.next_functions[0][0] grad_fn_k.register_prehook(bwd_hook(layer_type='k')) - grad_fn_k = value.grad_fn.next_functions[0][0] - grad_fn_k.register_prehook(bwd_hook(layer_type='v')) #out shape : e.g., [s:h/p:] From 171eb67ea7f62014e6624b5db99f4ed9601925f7 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 10 Jul 2024 07:06:49 +0000 Subject: [PATCH 16/16] fix notes and format --- deepspeed/sequence/layer.py | 53 ++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index cba7cf913729..f17cfa883cc6 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -11,14 +11,18 @@ import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator + def post_all2all(transpose, res_shape): + def post_func(input): if transpose: - input=input.transpose(0, 2).contiguous() - input=input.reshape(res_shape) + input = input.transpose(0, 2).contiguous() + input = input.reshape(res_shape) return input + return post_func - + + def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) @@ -43,18 +47,18 @@ def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, han inp_shape[gather_idx + 1:]) transpose = True if scatter_idx < 2 else False post_all2all_fun = post_all2all(transpose, res_shape) - + if async_op: if type in ('dq', 'dk'): - handle[type+'_work']=work + handle[type + '_work'] = work handle[type + '_grad'] = output - handle[type+'_post_all2all_func'] = post_all2all_fun + handle[type + '_post_all2all_func'] = post_all2all_fun return output.view(res_shape) - res=post_all2all_fun(output) + res = post_all2all_fun(output) return res - + class _SeqAllToAll(torch.autograd.Function): @staticmethod @@ -75,9 +79,9 @@ def forward(ctx: Any, ctx.type = type if ctx.handle is None: res = single_all_to_all(input, scatter_idx, gather_idx, group, False) - + else: - # overlap communcation path + # overlap communication path if not is_fwd and type == 'o': assert ctx.stream != None res = single_all_to_all(input, scatter_idx, gather_idx, group, False) @@ -86,12 +90,12 @@ def forward(ctx: Any, # The computation of d o_weight can overlap with the communication of d o_input elif not is_fwd and type in ('q', 'k'): - # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v + # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv type = 'd' + type res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) - + elif is_fwd and type in ('q', 'k'): - # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv + # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'fwd_' + type res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) @@ -137,12 +141,13 @@ def __init__( self.sp_stream = sp_stream if sp_stream is not None: self.overlap_handles = {} - self.sp_overlap_comm =True + self.sp_overlap_comm = True self.dafult_stream = get_accelerator().default_stream() def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) + def 
forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: """ forward @@ -163,12 +168,12 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg def bwd_hook(layer_type): def pre_hook_fun(grad): - type='d' + layer_type - self.overlap_handles[type +'_work'].wait() + type = 'd' + layer_type + self.overlap_handles[type + '_work'].wait() self.sp_stream.wait_stream(self.dafult_stream) all2all_output = self.overlap_handles[type + '_grad'] grad = list(grad) - grad[0]=self.overlap_handles[type + '_post_all2all_func'](all2all_output) + grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output) grad = tuple(grad) return pre_hook_fun @@ -181,27 +186,27 @@ def pre_hook_fun(grad): 'k') if self.sp_overlap_comm: self.dafult_stream.wait_stream(self.sp_stream) - + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, self.overlap_handles, 'v') if self.sp_overlap_comm: - # Register a hook to synchronize dq and dk after the all-to-all - # operation when the gradient data is used. - # Place this logic after the q, k, v all-to-all operation to - # improve interpreter speed to + # Register a hook to synchronize dq and dk after the all-to-all + # operation when the gradient data is used. + # Place this logic after the q, k, v all-to-all operation to + # improve interpreter speed to # call and launch of the forward all-to-all communication. grad_fn_q = query.grad_fn.next_functions[0][0] grad_fn_q.register_prehook(bwd_hook(layer_type='q')) grad_fn_k = key.grad_fn.next_functions[0][0] grad_fn_k.register_prehook(bwd_hook(layer_type='k')) - #out shape : e.g., [s:h/p:] context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, self.overlap_handles,'o') + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, + self.overlap_handles, 'o') #out e.g., [s/p::h] return output
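
The central pattern these patches introduce is launching `all_to_all_single` with `async_op=True`, stashing the returned work handle together with a deferred reshape closure (`post_all2all`), and waiting only at the point where the exchanged tensor is actually consumed. The following standalone sketch illustrates that pattern outside DeepSpeed; it is not part of the patch set, the script name and tensor shapes are made up, and it assumes a multi-GPU launch via `torchrun` with the NCCL backend.

```python
# sp_all2all_sketch.py -- illustrative only, not part of the patch set.
# Launch all_to_all_single asynchronously, keep the work handle plus a deferred
# reshape closure, overlap independent compute, and wait() only where the
# exchanged tensor is consumed. Run with:
#   torchrun --nproc_per_node=<num_gpus> sp_all2all_sketch.py
import torch
import torch.distributed as dist


def deferred_post(transpose, out_shape):
    # Analogue of post_all2all: the transpose/reshape is applied only after
    # the asynchronous communication has completed.
    def post(t):
        if transpose:
            t = t.transpose(0, 2).contiguous()
        return t.reshape(out_shape)

    return post


def main():
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    x = torch.randn(world * 2, 3, device="cuda")
    out = torch.empty_like(x)
    work = dist.all_to_all_single(out, x, async_op=True)  # returns a Work handle
    post = deferred_post(transpose=False, out_shape=(world, 2, 3))

    y = (x * x).sum()  # independent compute overlaps with the collective
    work.wait()        # block only once the exchanged data is actually needed
    z = post(out)
    if rank == 0:
        print(y.item(), tuple(z.shape))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```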
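
The dq/dk overlap relies on `torch.autograd.graph.Node.register_prehook`: a pre-hook registered on the producer node of `query`/`key` waits on the outstanding all-to-all and substitutes its post-processed output for the incoming gradient. Below is a single-process sketch of that mechanism only, with hypothetical tensors and a constant stand-in for the overlapped all-to-all result; it is not the DeepSpeed code path itself.

```python
import torch

x = torch.randn(4, requires_grad=True)
h = x * 2          # h.grad_fn is the Mul node the hook is attached to
y = h.sum()

def pre_hook(grad_outputs):
    # Fires before the Mul node runs in backward. In the patches this is where
    # handle['d<type>_work'].wait() and the deferred reshape happen; here a
    # constant tensor stands in for the overlapped all-to-all output.
    patched = torch.full_like(grad_outputs[0], 5.0)
    return (patched,)  # returning a tuple substitutes the incoming gradient

h.grad_fn.register_prehook(pre_hook)
y.backward()
print(x.grad)  # 2 * 5 = 10 everywhere, so the substituted gradient was used
```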
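
The `sp_stream` handling (`wait_stream`, `wait_event`, `layer_sync`) is standard CUDA stream ordering: work queued on a side stream is ordered against the default stream with stream/event waits instead of a blocking device synchronize. A tiny illustrative sketch follows; it requires a CUDA device and all names in it are made up.

```python
import torch

assert torch.cuda.is_available()  # sketch requires a CUDA device
side = torch.cuda.Stream()        # stands in for sp_stream
x = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(side):
    y = x @ x                     # queued on the side stream

torch.cuda.current_stream().wait_stream(side)  # default stream waits on side
y.record_stream(torch.cuda.current_stream())   # keep y alive for the consumer
z = y.sum()                       # now safe to consume y on the default stream
print(z.item())
```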