Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hybrid Performance] Move the cast op of AMP which cast fp32 param to fp16 param to the optimizer #34965

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ message ShardingConfig {
optional bool optimize_offload = 9 [ default = false ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
}

message HybridConfig {
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,9 @@ def sharding_configs(self):
pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on.
This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False.

optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimizer. optimize_cast will persist fp16 param, it
will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large.


Examples:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name

__all__ = []
Expand Down Expand Up @@ -84,7 +84,7 @@ def _create_offload_var(self, var_name, offload_var_name, blocks):
dtype=var.dtype,
persistable=True)

def offload_fp32param(self, block, startup_block):
def offload_fp32param(self, block, startup_block, offload=True):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
Expand Down Expand Up @@ -113,19 +113,20 @@ def remove_param(input_name):

# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
if is_update_op(op):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx

# step2: remove param which can't offload
# step2: remove param which can't offload and
# record param->fp16param, fp16param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue

# param is real used by fp32 op
# param which will be used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
Expand Down Expand Up @@ -154,17 +155,19 @@ def remove_param(input_name):
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])

# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)

assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
Expand All @@ -173,8 +176,9 @@ def remove_param(input_name):
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])

# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue

# step3.4: remove cast op
Expand Down Expand Up @@ -206,9 +210,10 @@ def remove_param(input_name):

if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
var_name, offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])

Expand All @@ -217,6 +222,19 @@ def remove_param(input_name):
block._sync_with_cpp()
startup_block._sync_with_cpp()

def cast_fp32param_in_optimize(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)

(pout,) = adam(p)
(p_fp16) = cast(p)
"""
self.offload_fp32param(block, startup_block, offload=False)

def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,14 @@ def _apply_optimize_offload_pass(self):
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
elif sharding_configs['optimize_cast']:
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast will persist fp16 param, it
# will take more memory, but will be faster. Trade space for time.
offload_helper = OffloadHelper()
offload_helper.cast_fp32param_in_optimize(main_block, startup_block)

def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
Expand Down Expand Up @@ -444,6 +451,7 @@ def minimize_impl(self,
# loss div dp_degree
self._insert_loss_grad_scale_op()

# apply optimize offload or optimize cast
self._apply_optimize_offload_pass()

# step6: (optional) sharding gradient merge
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,197 @@ def test_hybrid_with_sharding_pp_amp_fp16allreduce_in_optimize(self):

self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])

def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
"optimize_cast": True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']

startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops

# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]

# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast',
'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum',
'cast'
])

# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)

# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)

# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
pp_group_waiting_ports = op.desc.attr("other_endpoints")

self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])

# check correctness of dp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_3":
dp_group_waiting_ports = op.desc.attr("other_endpoints")

self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])

def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.amp = True
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
"optimize_offload": True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
strategy.fp16_allreduce = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']

startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops

# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]

# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'scale', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy',
'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy',
'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum',
'memcpy', 'momentum', 'cast', 'memcpy'
])

# amp check_finite_and_unscale, allreduce(pp)
self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)

# should has ring id for pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.pp_pair_ring_id, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)

# check correctness of pp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_0":
pp_group_waiting_ports = op.desc.attr("other_endpoints")

self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])

# check correctness of dp group
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "comm_id_3":
dp_group_waiting_ports = op.desc.attr("other_endpoints")

self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])


if __name__ == "__main__":
unittest.main()