From 9cc3f69f9e159ebef64d82b570d915e0da803362 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 11 Oct 2022 14:21:40 +0800 Subject: [PATCH] Cherry pick for dygraph pp (#46876) * bug fix for virtual pipeline parallel (#45922) * dont wait for send op under dygraph pp (#46209) * [interleave pp] sync recv for 1f1b (#46399) * [dygraph pp] all sync for allgather partial (#46483) --- .../parallel_layers/pp_layers.py | 6 +- .../fleet/meta_parallel/pipeline_parallel.py | 5 +- .../pp_utils/p2p_communication.py | 271 ++++++++++-------- 3 files changed, 155 insertions(+), 127 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 926e6aab81e56..5824fbe6df2d6 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -378,7 +378,7 @@ def get_stage_from_index(self, layer_idx): for virtual_pp_rank in range(self._num_virtual_pipeline_stages): # Mapping the virtual pipeline stage to the real pipeline stage. # start_idx marks the start of a new virtual pp stage. - start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages + start_idx = virtual_pp_rank * self._num_stages for stage in range(self._num_stages): # stage mark the real pp stage if self.segment_parts[start_idx + @@ -484,7 +484,7 @@ def _segment_network_for_interleave(self, seg_method): ", ".join(str(arg) for arg in self.segment_parts)) for i in range(self._stage_id, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers. # Layers [0, 1], [4, 5] will be assigned to the first real pp stage. # Layers [2, 3], [6, 7] will be assigned to the second real pp stage. 
@@ -529,7 +529,7 @@ def _print_segmentation_for_debug(self): stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format( stage) for i in range(stage, self._total_stages_with_virtual_stages, - self._num_virtual_pipeline_stages): + self._num_stages): stage_to_virtual_stage_info += " {},".format(i) logger.info(stage_to_virtual_stage_info) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 02a1b421526df..56429b748064d 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -526,7 +526,7 @@ def interleave_pipeline(self, self.set_virtual_pipeline_rank(0) self.input_tensors[0].append( - p2p.recv_forward(self.is_pipeline_first_stage())) + p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)) # run startup steps for micro_step in range(startup_steps): @@ -647,7 +647,8 @@ def interleave_pipeline(self, if not forward_only: if all_startup_steps: self.output_tensor_grads[self.num_model_chunks - 1].append( - p2p.recv_backward(self.is_pipeline_last_stage())) + p2p.recv_backward(self.is_pipeline_last_stage(), + sync_recv=False)) for micro_step in range(steady_steps, num_steps): # cooldown loop diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 3b4094f047552..e2ca6f8d2a034 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -165,17 +165,15 @@ def _is_valid_send_recv_partial(tensor, mp_degree): def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id): - dst_rank_in_group = dst if group is None else group.get_group_rank(dst) if _in_legacy_dygraph(): return _legacy_C_ops.partial_send(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', dst_rank_in_group, 'num', - nranks, 'id', rank_id) + 'peer', dst, 'num', nranks, 'id', + rank_id) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.send_partial(tensor, dst_rank_in_group, - nranks, rank_id) + return group.process_group.send_partial(tensor, dst, nranks, rank_id) def send_partial(tensor, @@ -189,13 +187,12 @@ def send_partial(tensor, return ring_id = 0 if group is None else group.id - dst_rank = _hcg._get_p2p_next_rank( - ) if dst == 1 else _hcg._get_p2p_prev_rank() - if _is_valid_send_recv_partial(tensor, nranks): - return _partial_send_op(tensor, group, use_calc_stream, ring_id, - dst_rank, nranks, rank_id) + return _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, + nranks, rank_id) else: + dst_rank = _hcg._get_p2p_next_rank( + ) if dst == 1 else _hcg._get_p2p_prev_rank() if _in_legacy_dygraph(): send_op = paddle.distributed.send elif in_dygraph_mode(): @@ -205,19 +202,22 @@ def send_partial(tensor, def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks, rank_id): - src_rank_in_group = src if group is None else group.get_group_rank(src) if _in_legacy_dygraph(): + assert use_calc_stream return _legacy_C_ops.partial_recv(tensor.detach(), 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'peer', src_rank_in_group, 'num', - nranks, 'id', rank_id, 'dtype', - tensor.dtype, 'out_shape', - tensor.shape) + 'peer', 
src, 'num', nranks, 'id', + rank_id, 'dtype', tensor.dtype, + 'out_shape', tensor.shape) elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.recv_partial(tensor, src_rank_in_group, - nranks, rank_id) + task = group.process_group.recv_partial(tensor, src, nranks, rank_id) + if use_calc_stream: + task.wait() + return None + else: + return task def recv_partial(tensor, @@ -231,14 +231,13 @@ def recv_partial(tensor, return ring_id = 0 if group is None else group.id - src_rank = _hcg._get_p2p_prev_rank( - ) if src == 0 else _hcg._get_p2p_next_rank() - if _is_valid_send_recv_partial(tensor, nranks): - return _partial_recv_op(tensor, group, use_calc_stream, ring_id, - src_rank, nranks, rank_id) + return _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, + nranks, rank_id) else: - if _in_legacy_dygraph(): + src_rank = _hcg._get_p2p_prev_rank( + ) if src == 0 else _hcg._get_p2p_next_rank() + if _in_legacy_dygraph() or use_calc_stream: recv_op = paddle.distributed.recv elif in_dygraph_mode(): recv_op = paddle.distributed.irecv @@ -256,8 +255,13 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, elif in_dygraph_mode(): group = paddle.distributed.collective._get_default_group( ) if group is None else group - return group.process_group.all_gather_partial(tensor, tensor, nranks, + task = group.process_group.all_gather_partial(tensor, tensor, nranks, rank_id) + if use_calc_stream: + task.wait() + return None + else: + return task def allgather_partial(tensor, @@ -266,16 +270,20 @@ def allgather_partial(tensor, group=None, use_calc_stream=True): if not _is_valid_send_recv_partial(tensor, nranks): - return None + return tensor if group is not None and not group.is_member(): - return None + return ring_id = 0 if group is None else group.id return _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks, rank_id) -def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): +def _p2p_helper(tensor_send_next, + tensor_send_prev, + recv_prev, + recv_next, + sync_recv=True): global _hcg tensor_recv_prev = None @@ -327,126 +335,140 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): if tensor_send_prev is not None: if isinstance(tensor_send_prev, tuple): for d in tensor_send_prev: - if _in_legacy_dygraph(): - paddle.distributed.wait(d, use_calc_stream=True) - tasks.append( - send_partial(d, - dst=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_prev_group, - use_calc_stream=False)) - else: - if _in_legacy_dygraph(): - paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - tasks.append( - send_partial(tensor_send_prev, + paddle.distributed.wait(d, use_calc_stream=True) + send_partial(d, dst=0, nranks=mp_degree, rank_id=mp_rank, group=_hcg.send_prev_group, - use_calc_stream=False)) + use_calc_stream=False) + else: + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + send_partial(tensor_send_prev, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False) if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - tasks.append( - recv_partial(d, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=True)) + task = recv_partial(d, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(d, + 
nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) else: - tasks.append( - recv_partial(tensor_recv_prev, - src=0, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_prev_group, - use_calc_stream=True)) + task = recv_partial(tensor_recv_prev, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(tensor_recv_prev, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) if tensor_send_next is not None: if isinstance(tensor_send_next, tuple): for d in tensor_send_next: - if _in_legacy_dygraph(): - paddle.distributed.wait(d, use_calc_stream=True) - tasks.append( - send_partial(d, - dst=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.send_next_group, - use_calc_stream=False)) - else: - if _in_legacy_dygraph(): - paddle.distributed.wait(tensor_send_next, use_calc_stream=True) - tasks.append( - send_partial(tensor_send_next, + paddle.distributed.wait(d, use_calc_stream=True) + send_partial(d, dst=1, nranks=mp_degree, rank_id=mp_rank, group=_hcg.send_next_group, - use_calc_stream=False)) - - if tensor_recv_next is not None: - if isinstance(tensor_recv_next, tuple): - for d in tensor_recv_next: - tasks.append( - recv_partial(d, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=True)) - + use_calc_stream=False) else: - tasks.append( - recv_partial(tensor_recv_next, - src=1, - nranks=mp_degree, - rank_id=mp_rank, - group=_hcg.recv_next_group, - use_calc_stream=True)) - - if in_dygraph_mode(): - # wait isend/irecv tasks in eager dygraph mode with new comm library - for task in tasks: - assert task is not None - task.wait() + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + send_partial(tensor_send_next, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False) - tensors_for_all_gather = [] - if tensor_recv_prev is not None: - if isinstance(tensor_recv_prev, tuple): - for d in tensor_recv_prev: - tensors_for_all_gather.append(d) - else: - tensors_for_all_gather.append(tensor_recv_prev) if tensor_recv_next is not None: if isinstance(tensor_recv_next, tuple): for d in tensor_recv_next: - tensors_for_all_gather.append(d) - else: - tensors_for_all_gather.append(tensor_recv_next) + task = recv_partial(d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) - tasks = [] - for tensor in tensors_for_all_gather: - tasks.append( + else: + task = recv_partial(tensor_recv_next, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=sync_recv) + if sync_recv: + allgather_partial(tensor_recv_next, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + tasks.append(task) + + if not sync_recv: + if in_dygraph_mode(): + # wait irecv tasks in eager dygraph mode with new comm library + for task in tasks: + assert task is not None + task.wait() + + tensors_for_all_gather = [] + if tensor_recv_prev is not None: + if isinstance(tensor_recv_prev, tuple): + for d in tensor_recv_prev: + tensors_for_all_gather.append(d) + else: + tensors_for_all_gather.append(tensor_recv_prev) + if tensor_recv_next is not None: + if 
isinstance(tensor_recv_next, tuple): + for d in tensor_recv_next: + tensors_for_all_gather.append(d) + else: + tensors_for_all_gather.append(tensor_recv_next) + + for tensor in tensors_for_all_gather: allgather_partial(tensor, nranks=mp_degree, rank_id=mp_rank, group=mp_group, - use_calc_stream=True)) - - for task in tasks: - # wait partial all gather tasks - if task is not None: - task.wait() + use_calc_stream=True) return tensor_recv_prev, tensor_recv_next -def recv_forward(pp_first_stage): +def recv_forward(pp_first_stage, sync_recv=True): if pp_first_stage: input_tensor = None else: @@ -457,18 +479,20 @@ def recv_forward(pp_first_stage): input_tensor, _ = _p2p_helper(tensor_send_next=None, tensor_send_prev=None, recv_prev=True, - recv_next=False) + recv_next=False, + sync_recv=sync_recv) return input_tensor -def recv_backward(pp_last_stage): +def recv_backward(pp_last_stage, sync_recv=True): if pp_last_stage: output_tensor_grad = None else: _, output_tensor_grad = _p2p_helper(tensor_send_next=None, tensor_send_prev=None, recv_prev=False, - recv_next=True) + recv_next=True, + sync_recv=sync_recv) return output_tensor_grad @@ -530,7 +554,8 @@ def send_forward_backward_recv_forward_backward(output_tensor, tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, - recv_next=recv_next) + recv_next=recv_next, + sync_recv=False) return input_tensor, output_tensor_grad @@ -547,7 +572,8 @@ def send_forward_recv_forward(output_tensor, recv_prev): input_tensor, _ = _p2p_helper(tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, - recv_next=False) + recv_next=False, + sync_recv=False) return input_tensor @@ -556,5 +582,6 @@ def send_backward_recv_backward(input_tensor_grad, recv_next): _, output_tensor_grad = _p2p_helper(tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, - recv_next=recv_next) + recv_next=recv_next, + sync_recv=False) return output_tensor_grad
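
Note for readers following the first commit (#45922): the fix changes the stride used when walking the interleaved segment table from self._num_virtual_pipeline_stages to self._num_stages, so each real pipeline stage picks up exactly one segment per virtual stage. The following minimal, standalone sketch is not part of the patch and not Paddle API; the helper name interleave_layers and the even-split assumption are illustrative only. It reproduces the mapping described in the diff comment, where 2 real stages and 2 virtual stages split an 8-layer model into [0, 1], [4, 5] for the first real stage and [2, 3], [6, 7] for the second.

def interleave_layers(num_layers, num_stages, num_virtual_stages):
    """Return {real_stage: [layer indices]} assuming layers split evenly."""
    total_parts = num_stages * num_virtual_stages
    assert num_layers % total_parts == 0, "sketch assumes an even split"
    layers_per_part = num_layers // total_parts

    assignment = {stage: [] for stage in range(num_stages)}
    for virtual_rank in range(num_virtual_stages):
        # The part index strides by the number of *real* stages per virtual
        # rank, mirroring `start_idx = virtual_pp_rank * self._num_stages`
        # in the patched get_stage_from_index().
        for stage in range(num_stages):
            part = virtual_rank * num_stages + stage
            start = part * layers_per_part
            assignment[stage].extend(range(start, start + layers_per_part))
    return assignment


if __name__ == "__main__":
    # 2 real pp stages, 2 virtual pp stages, 8 layers -> matches the comment
    # in the diff: stage 0 owns [0, 1, 4, 5], stage 1 owns [2, 3, 6, 7].
    print(interleave_layers(num_layers=8, num_stages=2, num_virtual_stages=2))

Running the sketch prints {0: [0, 1, 4, 5], 1: [2, 3, 6, 7]}, matching the comment in _segment_network_for_interleave. The other three commits in this cherry-pick only change how the p2p receives are synchronized (the sync_recv flag and explicit task.wait() calls) and do not affect this layer-to-stage mapping.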