[HybridParallel] Fix bug of p2p for partial_send/recv #34615

Merged
merged 2 commits on Aug 5, 2021
14 changes: 0 additions & 14 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -64,18 +64,6 @@ def __init__(self, layers, hcg, strategy):
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)

def _set_tensor_trainable(self, tensor):
if tensor is None:
return

if isinstance(tensor, tuple):
for t in tensor:
if is_float_tensor(t):
t.stop_gradient = False
else:
if is_float_tensor(tensor):
tensor.stop_gradient = False

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.')
@@ -117,7 +105,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):

for step_id in range(startup_steps):
input_tensor = p2p.recv_forward()
self._set_tensor_trainable(input_tensor)

output_tensor = self._forward_step(input_tensor)
p2p.send_forward(output_tensor)
@@ -131,7 +118,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
for i in range(steady_steps):
last_iter = (i == (steady_steps - 1))

self._set_tensor_trainable(input_tensor)
output_tensor = self._forward_step(input_tensor)

output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)
@@ -15,6 +15,8 @@
import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger
import numpy as np
from paddle import _C_ops

_hcg = None

@@ -40,6 +42,7 @@ def __init__(self):

self.recv_shape_message = None
self.recv_dtype_message = None
self.recv_stop_gradient = None

self.has_send_meta = False
self.has_recv_meta = False
@@ -57,31 +60,39 @@ def _recv_shape_dtype(self, group):
# recv dtype
dtype = paddle.to_tensor([0])
paddle.distributed.recv(dtype, src=0, group=group)
return shape.numpy().tolist(), dtype.item()

# recv stop_gradient
stop_grad = paddle.to_tensor([0])
paddle.distributed.recv(stop_grad, src=0, group=group)
return shape.numpy().tolist(), dtype.item(), stop_grad.item()
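
For reference, the receive order in _recv_shape_dtype has to mirror _send_dims_shape_dtype on the sending stage exactly: number of dims, the shape itself, the dtype, and now the stop_gradient flag, each shipped as a small tensor over the same group. A rough stand-alone sketch of that handshake, using a plain Python list as the "wire" instead of paddle.distributed.send/recv (function names and values below are made up for illustration):

    def send_meta_fields(wire, shape, dtype_number, stop_gradient):
        wire.append(len(shape))          # number of dims
        wire.extend(shape)               # the shape itself, one value per dim
        wire.append(dtype_number)        # dtype encoded as an int (cf. paddle_2_number); arbitrary here
        wire.append(int(stop_gradient))  # new in this PR: the trainability flag

    def recv_meta_fields(wire):
        ndims = wire.pop(0)
        shape = [wire.pop(0) for _ in range(ndims)]
        dtype_number = wire.pop(0)
        stop_gradient = bool(wire.pop(0))
        return shape, dtype_number, stop_gradient

    wire = []
    send_meta_fields(wire, [16, 1024], 5, False)
    print(recv_meta_fields(wire))  # ([16, 1024], 5, False)
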

def recv_meta(self, group):
tensor_type = paddle.to_tensor([0])
paddle.distributed.recv(tensor_type, src=0, group=group)
tensor_type = tensor_type.item()

if tensor_type == 0:
shape, dtype = self._recv_shape_dtype(group)
shape, dtype, stop_grad = self._recv_shape_dtype(group)
self.recv_shape_message = shape
self.recv_dtype_message = dtype
self.recv_stop_gradient = bool(stop_grad)

elif tensor_type == 1:
num = paddle.to_tensor([0])
paddle.distributed.recv(num, src=0, group=group)
num = num.item()
shapes = []
dtypes = []
stop_grads = []
for i in range(num):
shape, dtype = self._recv_shape_dtype(group)
shape, dtype, stop_grad = self._recv_shape_dtype(group)
shapes.append(shape)
dtypes.append(dtype)
stop_grads.append(bool(stop_grad))

self.recv_shape_message = tuple(shapes)
self.recv_dtype_message = tuple(dtypes)
self.recv_stop_gradient = tuple(stop_grads)

def _send_dims_shape_dtype(self, tensor, group):
# send len(shape)
@@ -96,6 +107,10 @@ def _send_dims_shape_dtype(self, tensor, group):
dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
paddle.distributed.send(dtype, dst=1, group=group)

# send trainable
stop_grad = paddle.to_tensor(int(tensor.stop_gradient))
paddle.distributed.send(stop_grad, dst=1, group=group)
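
Why the trainability flag travels with the meta: on the receiving stage the p2p buffers seed the local autograd graph, and backward only produces gradients for tensors whose stop_gradient is False, which is what the pipeline relies on when sending activation gradients back to the previous stage. A self-contained illustration of that Paddle behavior, independent of the pipeline code:

    import paddle

    x = paddle.ones([2, 2])
    x.stop_gradient = False   # a trainable activation: a gradient must flow back for it
    m = paddle.ones([2, 2])
    m.stop_gradient = True    # e.g. an attention mask: no gradient should be produced

    loss = (x * m).sum()
    loss.backward()
    print(x.grad is not None)  # True
    print(m.grad)              # None: backward skips stop_gradient tensors
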

def send_meta(self, tensor, group):
if isinstance(tensor, paddle.Tensor):
tensor_type = paddle.to_tensor([0])
@@ -129,6 +144,12 @@ def set_send_message(self, tensor):
_send_recv_meta = SendRecvMeta()


def _is_valid_send_recv_partial(tensor, mp_degree):
tensor_numel = np.prod(tensor.shape)
assert tensor_numel != 0, "can't send/recv zero element"
return mp_degree > 1 and tensor_numel % mp_degree == 0
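
For intuition, this helper only allows the partial path when the flattened tensor splits evenly across the model-parallel ranks; otherwise the callers below fall back to a full send/recv. A quick sketch with made-up shapes:

    import numpy as np

    def is_valid_send_recv_partial(tensor_shape, mp_degree):
        # mirrors _is_valid_send_recv_partial above, but takes a plain shape list
        tensor_numel = int(np.prod(tensor_shape))
        assert tensor_numel != 0, "can't send/recv zero element"
        return mp_degree > 1 and tensor_numel % mp_degree == 0

    print(is_valid_send_recv_partial([16, 1024], 8))  # True: 16384 % 8 == 0
    print(is_valid_send_recv_partial([3, 3], 8))      # False: 9 elements do not split into 8 parts
    print(is_valid_send_recv_partial([16, 1024], 1))  # False: no model parallelism, plain send/recv
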


def send_partial(tensor,
dst=0,
nranks=1,
@@ -138,9 +159,14 @@ def send_partial(tensor,
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
return paddle.fluid.core.ops.partial_send(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
dst, 'num', nranks, 'id', rank_id)

if _is_valid_send_recv_partial(tensor, nranks):
return _C_ops.partial_send(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst, 'num',
nranks, 'id', rank_id)
else:
return paddle.distributed.send(
tensor, dst=dst, group=group, use_calc_stream=use_calc_stream)

Contributor:
You could just set nranks=1 and rank_id=0 and keep using partial_send here.

Member (Author):
Yes, both would work, but since the check is already there, I think it is better to call the native send. 😊


def recv_partial(tensor,
@@ -153,26 +179,30 @@ def recv_partial(tensor,
return
ring_id = 0 if group is None else group.id

paddle.fluid.core.ops.partial_recv(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer',
src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
if _is_valid_send_recv_partial(tensor, nranks):
_C_ops.partial_recv(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'num', nranks,
'id', rank_id, 'dtype', tensor.dtype, 'out_shape',
tensor.shape)
else:
paddle.distributed.recv(
tensor, src=src, group=group, use_calc_stream=use_calc_stream)

Contributor:
Same as above.


def allgather_partial(tensor,
nranks=1,
rank_id=0,
group=None,
use_calc_stream=True):
if nranks == 1:
if not _is_valid_send_recv_partial(tensor, nranks):
return tensor
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id

return paddle.fluid.core.ops.partial_allgather_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id,
'nranks', nranks, 'rank', rank_id)
return _C_ops.partial_allgather_(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks,
'rank', rank_id)
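
Why an allgather follows every partial recv further down: partial_send is assumed to put only one of nranks equal slices of the flattened tensor on the wire, so after recv_partial each model-parallel rank holds just its own slice, and partial_allgather_ stitches the full tensor back together in place. A rough numpy model of that round trip (purely illustrative, not the real ops):

    import numpy as np

    nranks = 4
    full = np.arange(8, dtype=np.float32).reshape(2, 4)

    flat = full.reshape(-1)
    chunks = np.split(flat, nranks)                        # what each mp rank would transmit/receive
    restored = np.concatenate(chunks).reshape(full.shape)  # what the allgather reassembles

    assert np.array_equal(restored, full)
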


def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
@@ -184,6 +214,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
# send / recv message
recv_shape_msg = _send_recv_meta.recv_shape_message
recv_dtype_msg = _send_recv_meta.recv_dtype_message
recv_stop_gradient = _send_recv_meta.recv_stop_gradient

send_shape_msg = _send_recv_meta.send_shape_message
send_dtype_msg = _send_recv_meta.send_dtype_message

@@ -196,13 +228,16 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(recv_shape_msg, tuple):
tensor_recv_prev = []
for idx, shape in enumerate(recv_shape_msg):
tensor_recv_prev.append(
paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])))
tmp = paddle.empty(
shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))
tmp.stop_gradient = recv_stop_gradient[idx]
tensor_recv_prev.append(tmp)
tensor_recv_prev = tuple(tensor_recv_prev)
else:

tensor_recv_prev = paddle.empty(
shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg))
tensor_recv_prev.stop_gradient = recv_stop_gradient

if recv_next:
if isinstance(send_shape_msg, tuple):
@@ -222,7 +257,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_prev:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
d.detach(),
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
@@ -231,7 +266,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_prev, use_calc_stream=True)
send_partial(
tensor_send_prev,
tensor_send_prev.detach(),
dst=0,
nranks=mp_degree,
rank_id=mp_rank,
@@ -242,28 +277,28 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_prev, tuple):
for d in tensor_recv_prev:
recv_partial(
d,
d.detach(),
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
d,
d.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)
else:
recv_partial(
tensor_recv_prev,
tensor_recv_prev.detach(),
src=0,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_prev_group,
use_calc_stream=True)
allgather_partial(
tensor_recv_prev,
tensor_recv_prev.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
@@ -274,7 +309,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
for d in tensor_send_next:
paddle.distributed.wait(d, use_calc_stream=True)
send_partial(
d,
d.detach(),
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
@@ -283,7 +318,7 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
else:
paddle.distributed.wait(tensor_send_next, use_calc_stream=True)
send_partial(
tensor_send_next,
tensor_send_next.detach(),
dst=1,
nranks=mp_degree,
rank_id=mp_rank,
@@ -294,30 +329,30 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
if isinstance(tensor_recv_next, tuple):
for d in tensor_recv_next:
recv_partial(
d,
d.detach(),
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)
allgather_partial(
d,
d.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
use_calc_stream=True)

else:
recv_partial(
tensor_recv_next,
tensor_recv_next.detach(),
src=1,
nranks=mp_degree,
rank_id=mp_rank,
group=_hcg.recv_next_group,
use_calc_stream=True)

allgather_partial(
tensor_recv_next,
tensor_recv_next.detach(),
nranks=mp_degree,
rank_id=mp_rank,
group=mp_group,
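A note on the d.detach() arguments introduced above: detaching keeps the raw partial_send / partial_recv / partial_allgather_ calls out of the autograd graph, while in Paddle dygraph the detached view still shares storage with the original tensor, so the in-place recv/allgather fills exactly the buffer the forward pass will consume. A minimal illustration of the detach semantics this relies on:

    import paddle

    t = paddle.zeros([4])
    t.stop_gradient = False
    d = t.detach()
    print(d.stop_gradient)   # True: communication ops applied to `d` are not recorded
    # Tensor.detach() returns a view that shares data with `t`, so writing received
    # values into `d` in place (as recv_partial / allgather_partial do) also fills `t`.
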
@@ -54,13 +54,17 @@ def forward(self, x):
attention_mask = paddle.tensor.triu(
(paddle.ones(
(length, length), dtype="float32") * -1e9), 1)
attention_mask.stop_gradient = True

no_used = paddle.ones((3, 3), dtype="int32")

w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb

attention_mask.stop_gradient = True
no_used.stop_gradient = True
# need to fix bug of backward()
return w_emb, attention_mask
return w_emb, attention_mask, no_used, p_emb


class TransformerNet(Layer):
@@ -99,12 +103,12 @@ def forward(self, x):

class TransformerNetPipe(TransformerNet):
def forward(self, args):
x, mask = args[0], args[1]
x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]

output = super().forward(x, mask)
output = output
output = output + p_emb
mask.stop_gradient = True
return output, mask
return output, mask, no_used, p_emb


class CriterionPipe(Layer):
@@ -175,6 +179,8 @@ def test_pp_model(self):
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03) add utest for loss

print("loss: ", loss)


if __name__ == "__main__":
unittest.main()