Commit
optimize pipeline performance with recompute and amp, test=allcase (#…
wangxicoding authored Aug 5, 2021
1 parent 1d7b75d commit 911c859
Showing 4 changed files with 87 additions and 11 deletions.
7 changes: 7 additions & 0 deletions python/paddle/fluid/backward.py
@@ -945,6 +945,13 @@ def _append_backward_ops_with_checkpoints_(
for op_desc in reversed(added_descs):
    grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
        op_desc, cpt.to_text(no_grad_dict[block.idx]), [])

    # Set device for grad_op according to forward Op
    if op_desc.has_attr(device_attr_name):
        op_device = op_desc.attr(device_attr_name)
        for g_op_desc in grad_op_desc:
            g_op_desc._set_attr(device_attr_name, op_device)

    for key in var_name_dict:
        _rename_arg_(grad_op_desc, key, var_name_dict[key])
    grad_op_descs.extend(grad_op_desc)
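For context, the intent of the hunk above can be condensed as follows. This is an illustrative sketch, not the exact code in backward.py: device_attr_name is the op-device attribute key used by the pipeline pass, and the helper name inherit_device is hypothetical.

# Illustrative sketch of the change above; not a drop-in for backward.py.
# Recomputed (checkpoint) forward ops already carry an op_device attribute
# assigned by the pipeline pass; without this step, their generated grad ops
# would keep the default device and could land on the wrong pipeline stage.
device_attr_name = "op_device"  # assumed attribute key used by the pipeline pass

def inherit_device(forward_op_desc, grad_op_descs):
    """Copy the forward op's device attribute onto its generated grad ops."""
    if not forward_op_desc.has_attr(device_attr_name):
        return
    op_device = forward_op_desc.attr(device_attr_name)
    for g_op_desc in grad_op_descs:
        g_op_desc._set_attr(device_attr_name, op_device)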
2 changes: 2 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -150,6 +150,8 @@ def _update_list(self):
    'c_identity',
    'c_concat',
    'c_allreduce_sum',
    'concat',
    'split',
}

# The set of ops that don't support fp16 calculation
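For orientation, a rough sketch of how these lists are typically consulted, assuming the edited set is the AMP gray list (ops that follow the dtype of their inputs). The set contents and the choose_dtype helper below are illustrative only; the real decision in fp16_utils.py also inspects each input variable's dtype.

# Rough, simplified sketch; not the actual Paddle AMP implementation.
# White-list ops are computed in fp16, black-list ops in fp32, and gray-list
# ops (now including concat/split) follow their inputs, so no extra cast ops
# need to be inserted around them inside a pipeline stage.
white_list = {'matmul', 'mul'}                   # illustrative contents
black_list = {'softmax_with_cross_entropy'}      # illustrative contents
gray_list = {'concat', 'split', 'c_identity', 'c_concat', 'c_allreduce_sum'}

def choose_dtype(op_type, inputs_are_fp16):
    if op_type in white_list:
        return 'float16'
    if op_type in black_list:
        return 'float32'
    if op_type in gray_list:
        # Follow the inputs: no cast is needed in either direction.
        return 'float16' if inputs_are_fp16 else 'float32'
    return 'float32'  # unlisted ops default to fp32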
23 changes: 22 additions & 1 deletion python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -110,6 +110,27 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
    op_device = op.attr('op_device')
    # NOTE(wangxi): optimize for pipeline, reduce one send.
    # if in_var is stop_gradient and prev_op device is `all`,
    # set cast_op device to `all`, can reduce send cast_var.
    # TODO: need remove this after we unified the dynamic
    # and static pipeline interface.
    if src_dtype == core.VarDesc.VarType.FP32 and in_var.stop_gradient:
        prev_op = None
        if in_var.op is op:
            prev_op = find_true_prev_op(block.ops, op,
                                        in_var_name)
        elif in_var.op is not None:
            prev_op = in_var.op

        prev_op_device = None
        if prev_op is not None:
            prev_op_device = prev_op.attr('op_device')

        if prev_op_device is not None and 'all' in prev_op_device:
            op_device = prev_op_device

    out_var = block.create_var(
        name=cast_name,
        dtype=dest_dtype,
@@ -124,7 +145,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    attrs={
        "in_dtype": in_var.dtype,
        "out_dtype": out_var.dtype,
        "op_device": op.attr("op_device")
        "op_device": op_device
    })
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
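A condensed sketch of the device-selection rule added above, using block, op, in_var, and in_var_name as in _insert_cast_op; find_true_prev_op is the existing helper in fp16_utils.py, while the choose_cast_device wrapper is hypothetical and only illustrates the branch taken when the source dtype is fp32.

# Condensed sketch of the logic above; not a drop-in replacement.
# If the fp32 input is stop_gradient and its producer runs on device 'all'
# (replicated on every pipeline stage), place the cast op on 'all' as well.
# Each stage then casts locally instead of receiving the casted tensor from
# another stage, removing one send_v2/recv_v2 pair per such input.
def choose_cast_device(block, op, in_var, in_var_name):
    op_device = op.attr('op_device')  # default: same device as the consumer op
    if not in_var.stop_gradient:
        return op_device
    # Locate the op that actually produced in_var.
    if in_var.op is op:
        prev_op = find_true_prev_op(block.ops, op, in_var_name)
    else:
        prev_op = in_var.op
    if prev_op is not None:
        prev_op_device = prev_op.attr('op_device')
        if prev_op_device is not None and 'all' in prev_op_device:
            op_device = prev_op_device  # inherit the replicated placement
    return op_device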
@@ -14,6 +14,10 @@

import unittest
import paddle
import paddle.fluid as fluid
import paddle.static as static
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import os

paddle.enable_static()
@@ -25,26 +29,34 @@ def setUp(self):
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"

    def test_pipeline_optimizer(self):
        import paddle.distributed.fleet as fleet
        import paddle.distributed.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        with paddle.fluid.device_guard("gpu:0"):
    def net(self):
        with static.device_guard("gpu:0"):
            input_x = paddle.fluid.layers.data(
                name="x", shape=[32], dtype='float32')
            input_y = paddle.fluid.layers.data(
                name="y", shape=[1], dtype='int64')
            input_z = paddle.fluid.layers.data(
                name="z", shape=[1], dtype="float32")
            with static.device_guard("gpu:all"):
                input_z = input_z * 1.0
                input_z.stop_gradient = True
            fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
            fc_1 = fc_1 * input_z

        with paddle.fluid.device_guard("gpu:1"):
        with static.device_guard("gpu:1"):
            fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
            fc_2 = fc_2 * input_z
            prediction = paddle.fluid.layers.fc(input=[fc_2],
                                                size=2,
                                                act='softmax')
            cost = paddle.fluid.layers.cross_entropy(
                input=prediction, label=input_y)
            avg_cost = paddle.fluid.layers.mean(x=cost)
        return avg_cost

    def test_pipeline_optimizer(self):
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.pipeline = True
@@ -53,9 +65,43 @@ def test_pipeline_optimizer(self):
            'accumulate_steps': 2
        }

        optimizer = paddle.fluid.optimizer.Adam(0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)
        train_prog, startup_prog = static.Program(), static.Program()
        with static.program_guard(train_prog, startup_prog):
            with fluid.unique_name.guard():
                avg_cost = self.net()

                optimizer = paddle.fluid.optimizer.Adam(0.01)
                optimizer = fleet.distributed_optimizer(
                    optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

    def test_pipeline_amp_optimizer(self):
        """ test pipeline&amp with device:all """
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.amp = True
        strategy.pipeline = True
        strategy.pipeline_configs = {
            'micro_batch_size': 1,
            'accumulate_steps': 2
        }

        train_prog, startup_prog = static.Program(), static.Program()
        with static.program_guard(train_prog, startup_prog):
            with fluid.unique_name.guard():
                avg_cost = self.net()

                optimizer = paddle.fluid.optimizer.Adam(0.01)
                optimizer = fleet.distributed_optimizer(
                    optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

        ops = train_prog._pipeline_opt['section_program'].global_block().ops
        ops = [op.type for op in ops]
        self.assertEqual(ops.count('send_v2'), 1)
        self.assertEqual(ops.count('recv_v2'), 1)


if __name__ == "__main__":
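As a small usage note, the new assertions read as: after the AMP and pipeline passes run, the stage program extracted by the pipeline optimizer should contain exactly one send_v2 and one recv_v2. A hypothetical inspection helper along those lines (not part of this commit), assuming the stage program is reached via train_prog._pipeline_opt['section_program'] as in the test:

from collections import Counter

def count_comm_ops(stage_program):
    """Count pipeline communication ops in a stage program's global block.

    Hypothetical helper for illustration; mirrors the checks in
    test_pipeline_amp_optimizer above.
    """
    op_types = Counter(op.type for op in stage_program.global_block().ops)
    return op_types['send_v2'], op_types['recv_v2']

# Usage sketch (after optimizer.minimize(avg_cost) in the test):
# sends, recvs = count_comm_ops(train_prog._pipeline_opt['section_program'])
# assert (sends, recvs) == (1, 1)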
