diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index 365652f96feb7..5e2f4ba721931 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -257,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d.detach(),
+                    d,
                     dst=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -266,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(
-                tensor_send_prev.detach(),
+                tensor_send_prev,
                 dst=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -277,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
                 recv_partial(
-                    d.detach(),
+                    d,
                     src=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_prev_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d.detach(),
+                    d,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
                     use_calc_stream=True)
         else:
             recv_partial(
-                tensor_recv_prev.detach(),
+                tensor_recv_prev,
                 src=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=_hcg.recv_prev_group,
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_prev.detach(),
+                tensor_recv_prev,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
@@ -309,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d.detach(),
+                    d,
                     dst=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -318,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(
-                tensor_send_next.detach(),
+                tensor_send_next,
                 dst=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -329,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
                 recv_partial(
-                    d.detach(),
+                    d,
                     src=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_next_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d.detach(),
+                    d,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
@@ -344,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             recv_partial(
-                tensor_recv_next.detach(),
+                tensor_recv_next,
                 src=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -352,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_next.detach(),
+                tensor_recv_next,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
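
For context on why dropping the per-call `.detach()` is safe here: in Paddle dygraph, `Tensor.detach()` returns a view that shares the original tensor's storage, so `send_partial`/`recv_partial`/`allgather_partial` read and write the same memory whether they are handed `d` or `d.detach()`. The change thus avoids constructing a throwaway view on every point-to-point call without altering communication semantics. A minimal standalone sketch of the sharing behavior (illustrative only; not from this PR or its tests):

import paddle

# Sketch: detach() yields a view cut from the autograd graph that shares
# storage with the original tensor, so an in-place write through the view
# (as a recv into the buffer effectively is) shows up in the original too.
x = paddle.ones([2, 2])
y = x.detach()            # view sharing x's storage, detached from autograd
y[0, 0] = 5.0             # in-place write through the detached view...
print(float(x[0, 0]))     # ...is visible through x as well: prints 5.0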