diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 1cec106caec82..16ea7de2946bf 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -64,18 +64,6 @@ def __init__(self, layers, hcg, strategy):
             logger.info("start broadcast dp parameters")
             broadcast_dp_parameters(self._layers, self._hcg)
 
-    def _set_tensor_trainable(self, tensor):
-        if tensor is None:
-            return
-
-        if isinstance(tensor, tuple):
-            for t in tensor:
-                if is_float_tensor(t):
-                    t.stop_gradient = False
-        else:
-            if is_float_tensor(tensor):
-                tensor.stop_gradient = False
-
     def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         assert isinstance(optimizer, HybridParallelOptimizer), (
             'optimizer should be HybridParallelOptimizer subclass.')
@@ -117,7 +105,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
 
         for step_id in range(startup_steps):
             input_tensor = p2p.recv_forward()
-            self._set_tensor_trainable(input_tensor)
 
             output_tensor = self._forward_step(input_tensor)
             p2p.send_forward(output_tensor)
@@ -131,7 +118,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         for i in range(steady_steps):
             last_iter = (i == (steady_steps - 1))
 
-            self._set_tensor_trainable(input_tensor)
            output_tensor = self._forward_step(input_tensor)
 
             output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index e533b2ef3f7a3..365652f96feb7 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -15,6 +15,8 @@
 import paddle
 from .utils import paddle_2_number, number_2_dtype
 from ...utils.log_util import logger
+import numpy as np
+from paddle import _C_ops
 
 _hcg = None
 
@@ -40,6 +42,7 @@ def __init__(self):
 
         self.recv_shape_message = None
         self.recv_dtype_message = None
+        self.recv_stop_gradient = None
 
         self.has_send_meta = False
         self.has_recv_meta = False
@@ -57,7 +60,11 @@ def _recv_shape_dtype(self, group):
         # recv dtype
         dtype = paddle.to_tensor([0])
         paddle.distributed.recv(dtype, src=0, group=group)
-        return shape.numpy().tolist(), dtype.item()
+
+        # recv stop_gradient
+        stop_grad = paddle.to_tensor([0])
+        paddle.distributed.recv(stop_grad, src=0, group=group)
+        return shape.numpy().tolist(), dtype.item(), stop_grad.item()
 
     def recv_meta(self, group):
         tensor_type = paddle.to_tensor([0])
@@ -65,9 +72,10 @@ def recv_meta(self, group):
         tensor_type = tensor_type.item()
 
         if tensor_type == 0:
-            shape, dtype = self._recv_shape_dtype(group)
+            shape, dtype, stop_grad = self._recv_shape_dtype(group)
             self.recv_shape_message = shape
             self.recv_dtype_message = dtype
+            self.recv_stop_gradient = bool(stop_grad)
 
         elif tensor_type == 1:
             num = paddle.to_tensor([0])
@@ -75,13 +83,16 @@ def recv_meta(self, group):
             num = num.item()
             shapes = []
             dtypes = []
+            stop_grads = []
             for i in range(num):
-                shape, dtype = self._recv_shape_dtype(group)
+                shape, dtype, stop_grad = self._recv_shape_dtype(group)
                 shapes.append(shape)
                 dtypes.append(dtype)
+                stop_grads.append(bool(stop_grad))
 
             self.recv_shape_message = tuple(shapes)
             self.recv_dtype_message = tuple(dtypes)
+            self.recv_stop_gradient = tuple(stop_grads)
 
     def _send_dims_shape_dtype(self, tensor, group):
         # send len(shape)
@@ -96,6 +107,10 @@ def _send_dims_shape_dtype(self, tensor, group):
         dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
         paddle.distributed.send(dtype, dst=1, group=group)
 
+        # send trainable
+        stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
+        paddle.distributed.send(stop_grad, dst=1, group=group)
+
     def send_meta(self, tensor, group):
         if isinstance(tensor, paddle.Tensor):
             tensor_type = paddle.to_tensor([0])
@@ -129,6 +144,12 @@ def set_send_message(self, tensor):
 _send_recv_meta = SendRecvMeta()
 
 
+def _is_valid_send_recv_partial(tensor, mp_degree):
+    tensor_numel = np.prod(tensor.shape)
+    assert tensor_numel != 0, "can't send/recv zero element"
+    return mp_degree > 1 and tensor_numel % mp_degree == 0
+
+
 def send_partial(tensor,
                  dst=0,
                  nranks=1,
@@ -138,9 +159,14 @@ def send_partial(tensor,
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
-    return paddle.fluid.core.ops.partial_send(
-        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
-        dst, 'num', nranks, 'id', rank_id)
+
+    if _is_valid_send_recv_partial(tensor, nranks):
+        return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
+                                   'ring_id', ring_id, 'peer', dst, 'num',
+                                   nranks, 'id', rank_id)
+    else:
+        return paddle.distributed.send(
+            tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)
 
 
 def recv_partial(tensor,
@@ -153,10 +179,14 @@ def recv_partial(tensor,
         return
     ring_id = 0 if group is None else group.id
 
-    paddle.fluid.core.ops.partial_recv(
-        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
-        src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
-        tensor.shape)
+    if _is_valid_send_recv_partial(tensor, nranks):
+        _C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
+                            'ring_id', ring_id, 'peer', src, 'num', nranks,
+                            'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
+                            tensor.shape)
+    else:
+        paddle.distributed.recv(
+            tensor, src=src, group=group, use_calc_stream=use_calc_stream)
 
 
 def allgather_partial(tensor,
@@ -164,15 +194,15 @@ def allgather_partial(tensor,
                       rank_id=0,
                       group=None,
                       use_calc_stream=True):
-    if nranks == 1:
+    if not _is_valid_send_recv_partial(tensor, nranks):
         return tensor
     if group is not None and not group.is_member():
         return
     ring_id = 0 if group is None else group.id
 
-    return paddle.fluid.core.ops.partial_allgather_(
-        tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
-        'nranks', nranks, 'rank', rank_id)
+    return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
+                                     'ring_id', ring_id, 'nranks', nranks,
+                                     'rank', rank_id)
 
 
 def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     # send / recv message
     recv_shape_msg = _send_recv_meta.recv_shape_message
     recv_dtype_msg = _send_recv_meta.recv_dtype_message
+    recv_stop_gradient = _send_recv_meta.recv_stop_gradient
+
     send_shape_msg = _send_recv_meta.send_shape_message
     send_dtype_msg = _send_recv_meta.send_dtype_message
 
@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(recv_shape_msg, tuple):
             tensor_recv_prev = []
             for idx, shape in enumerate(recv_shape_msg):
-                tensor_recv_prev.append(
-                    paddle.empty(
-                        shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
+                tmp = paddle.empty(
+                    shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))
+                tmp.stop_gradient = recv_stop_gradient[idx]
+                tensor_recv_prev.append(tmp)
             tensor_recv_prev = tuple(tensor_recv_prev)
         else:
+
             tensor_recv_prev = paddle.empty(
                 shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
+            tensor_recv_prev.stop_gradient = recv_stop_gradient
 
     if recv_next:
         if isinstance(send_shape_msg, tuple):
@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_prev:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d,
+                    d.detach(),
                     dst=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
             send_partial(
-                tensor_send_prev,
+                tensor_send_prev.detach(),
                 dst=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_prev, tuple):
             for d in tensor_recv_prev:
                 recv_partial(
-                    d,
+                    d.detach(),
                     src=0,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_prev_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d,
+                    d.detach(),
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
                     use_calc_stream=True)
         else:
             recv_partial(
-                tensor_recv_prev,
+                tensor_recv_prev.detach(),
                 src=0,
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=_hcg.recv_prev_group,
                 use_calc_stream=True)
             allgather_partial(
-                tensor_recv_prev,
+                tensor_recv_prev.detach(),
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
             for d in tensor_send_next:
                 paddle.distributed.wait(d, use_calc_stream=True)
                 send_partial(
-                    d,
+                    d.detach(),
                     dst=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         else:
             paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
             send_partial(
-                tensor_send_next,
+                tensor_send_next.detach(),
                 dst=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -294,14 +329,14 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         if isinstance(tensor_recv_next, tuple):
             for d in tensor_recv_next:
                 recv_partial(
-                    d,
+                    d.detach(),
                     src=1,
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=_hcg.recv_next_group,
                     use_calc_stream=True)
                 allgather_partial(
-                    d,
+                    d.detach(),
                     nranks=mp_degree,
                     rank_id=mp_rank,
                     group=mp_group,
@@ -309,7 +344,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
 
         else:
             recv_partial(
-                tensor_recv_next,
+                tensor_recv_next.detach(),
                 src=1,
                 nranks=mp_degree,
                 rank_id=mp_rank,
@@ -317,7 +352,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                 use_calc_stream=True)
 
             allgather_partial(
-                tensor_recv_next,
+                tensor_recv_next.detach(),
                 nranks=mp_degree,
                 rank_id=mp_rank,
                 group=mp_group,
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
index b336330836a66..62b1a8b1da679 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
@@ -54,13 +54,17 @@ def forward(self, x):
         attention_mask = paddle.tensor.triu(
             (paddle.ones(
                 (length, length), dtype="float32") * -1e9), 1)
-        attention_mask.stop_gradient = True
+
+        no_used = paddle.ones((3, 3), dtype="int32")
+
         w_emb = self.word_embeddings(x)
         p_emb = self.position_embeddings(x)
         w_emb = w_emb + p_emb
 
+        attention_mask.stop_gradient = True
+        no_used.stop_gradient = True
         # need to fix bug of backward()
-        return w_emb, attention_mask
+        return w_emb, attention_mask, no_used, p_emb
 
 
 class TransformerNet(Layer):
@@ -99,12 +103,12 @@ def forward(self, x):
 
 class TransformerNetPipe(TransformerNet):
     def forward(self, args):
-        x, mask = args[0], args[1]
+        x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]
 
         output = super().forward(x, mask)
-        output = output
+        output = output + p_emb
         mask.stop_gradient = True
-        return output, mask
+        return output, mask, no_used, p_emb
 
 
 class CriterionPipe(Layer):
@@ -175,6 +179,8 @@ def test_pp_model(self):
             loss = model.train_batch([x, x], optimizer, scheduler)
 
             # TODO(shenliang03) add utest for loss
+            print("loss: ", loss)
+
 
 if __name__ == "__main__":
     unittest.main()
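
Reviewer note (illustration only, not part of the patch): the key guard introduced above is `_is_valid_send_recv_partial`, which routes a tensor through the `partial_send` / `partial_recv` / `partial_allgather_` ops only when model parallelism is enabled and the tensor's element count splits evenly across the model-parallel group; otherwise it falls back to a plain `paddle.distributed.send` / `recv`. A minimal standalone sketch of that decision, rewritten to take a plain shape list so it runs without Paddle (the function name and example shapes here are hypothetical), is:

    import numpy as np

    def is_valid_send_recv_partial(shape, mp_degree):
        # Partial send/recv ships one slice per model-parallel rank, so it is
        # only usable when mp_degree > 1 and numel divides evenly by mp_degree.
        tensor_numel = np.prod(shape)
        assert tensor_numel != 0, "can't send/recv zero element"
        return mp_degree > 1 and tensor_numel % mp_degree == 0

    print(is_valid_send_recv_partial([16, 1024], 4))  # True:  16384 % 4 == 0
    print(is_valid_send_recv_partial([3, 3], 4))      # False: 9 % 4 != 0, falls back to full send/recv
    print(is_valid_send_recv_partial([16, 1024], 1))  # False: no model parallelism

The `[3, 3]` case mirrors the `no_used = paddle.ones((3, 3), dtype="int32")` tensor added in the unit test, which cannot be split evenly across ranks and therefore exercises the new fallback path.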