From 4171aa6f75fffa9d492b6b70d1ace7ac411aa69d Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 17 Aug 2021 14:18:14 +0800 Subject: [PATCH 1/5] remove fp32param cast in hybrid --- .../framework/distributed_strategy.proto | 1 + .../fleet/base/distributed_strategy.py | 3 ++ .../sharding/offload_helper.py | 48 +++++++++++++------ .../meta_optimizers/sharding_optimizer.py | 8 ++++ 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 1de6d26d05b9e..b4a3caff1675b 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -42,6 +42,7 @@ message ShardingConfig { optional bool optimize_offload = 9 [ default = false ]; optional bool pp_allreduce_in_optimize = 10 [ default = false ]; optional int32 pp_degree = 11 [ default = 1 ]; + optional bool remove_param_cast = 12 [ default = false ]; } message HybridConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 051f6b11c2609..f72a9df1653a2 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -888,6 +888,9 @@ def sharding_configs(self): pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on. This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False. + remove_param_cast(bool, optional): [Hybrid parallelism ONLY] Remove the cast OP from the fp32 param to the fp16 param. Remove fp32 param cast will persist fp16 param, it + will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large. + Examples: diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py index f6741b165ce07..46b6766ab1ce1 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
-from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole +from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op from paddle.fluid import core, unique_name __all__ = [] @@ -84,7 +84,7 @@ def _create_offload_var(self, var_name, offload_var_name, blocks): dtype=var.dtype, persistable=True) - def offload_fp32param(self, block, startup_block): + def offload_fp32param(self, block, startup_block, offload=True): """ (p_fp16) = cast(p) (p_fp16_recompute) = cast(p) @@ -113,11 +113,12 @@ def remove_param(input_name): # step1: record param for idx, op in reversed(list(enumerate(block.ops))): - if op.type in ('adam', 'momentum', 'lars', 'lamb'): + if is_update_op(op): param = op.desc.input("Param")[0] param_to_idx[param] = idx - # step2: remove param which can't offload + # step2: remove param which can't offload and + # record param->fp16param, fp16param->recompute_var for idx, op in enumerate(block.ops): if is_optimizer_op(op): break @@ -125,7 +126,7 @@ def remove_param(input_name): if input_name not in param_to_idx: continue - # param is real used by fp32 op + # param which will be used by fp32 op if op.type != 'cast': remove_param(input_name) continue @@ -154,17 +155,19 @@ def remove_param(input_name): # step3: main_block add offload, cast op # change recompute to fp16, remove cast(param) to fp16 for idx, op in reversed(list(enumerate(block.ops))): - if op.type in ('adam', 'momentum', 'lars', 'lamb'): + if is_update_op(op): param = op.desc.input("Param")[0] if param not in param_to_idx: continue # step3.1: create offload_var offload_var_name = self._get_offload_var_name(param) param_name_to_offload_name[param] = offload_var_name - self._create_offload_var(param, offload_var_name, - [block, startup_block]) + if offload: + self._create_offload_var(param, offload_var_name, + [block, startup_block]) - # step3.2: insert cast op and offload op - self._insert_offload_op(block, idx + 1, param, offload_var_name) + # step3.2: insert cast op and offload op + self._insert_offload_op(block, idx + 1, param, + offload_var_name) assert param in param_to_fp16 fp16_param_name = param_to_fp16[param] @@ -173,8 +176,9 @@ def remove_param(input_name): self._insert_cast_op(block, idx + 1, param, param_to_fp16[param]) - # step3.3: insert fetch op - self._insert_fetch_op(block, idx, offload_var_name, param) + if offload: + # step3.3: insert fetch op + self._insert_fetch_op(block, idx, offload_var_name, param) continue # step3.4: remove cast op @@ -206,9 +210,10 @@ def remove_param(input_name): if out_name in param_name_to_offload_name: var_name = out_name - offload_var_name = param_name_to_offload_name[var_name] - self._insert_offload_op(startup_block, idx + 1, var_name, - offload_var_name) + if offload: + offload_var_name = param_name_to_offload_name[var_name] + self._insert_offload_op(startup_block, idx + 1, + var_name, offload_var_name) self._insert_cast_op(startup_block, idx + 1, var_name, param_to_fp16[var_name]) @@ -217,6 +222,19 @@ def remove_param(input_name): block._sync_with_cpp() startup_block._sync_with_cpp() + def remove_fp32param_cast(self, block, startup_block): + """ + (p_fp16) = cast(p) + (p_fp16_recompute) = cast(p) + (pout,) = adam(p) + ===========================> + rename(p_fp16_recompute, p_fp16) + + (pout,) = adam(p) + (p_fp16) = cast(p) + """ + self.offload_fp32param(block, startup_block, offload=False) + def offload(self, block, startup_block): """ (m1, m2) = prefetch(m1@offload, m2@offload) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py 
b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 93901b38873b9..34a50c73d5f41 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -400,7 +400,14 @@ def _apply_optimize_offload_pass(self): logger.info("Sharding with optimize offload !") offload_helper = OffloadHelper() offload_helper.offload(main_block, startup_block) + # The remove_fp32param_cast is already included in offload_fp32param offload_helper.offload_fp32param(main_block, startup_block) + elif sharding_configs['remove_param_cast']: + logger.info("Sharding with remove fp32param cast !") + # NOTE(wangxi): Remove fp32 param cast will persist fp16 param, it + # will take more memory, but will be faster. Trade space for time. + offload_helper = OffloadHelper() + offload_helper.remove_fp32param_cast(main_block, startup_block) def _dump_program_for_debug(self): main_block = self._main_program.global_block() @@ -444,6 +451,7 @@ def minimize_impl(self, # loss div dp_degree self._insert_loss_grad_scale_op() + # apply optimize offload or remove fp32param cast self._apply_optimize_offload_pass() # step6: (optional) sharding gradient merge From d03712971902e4654e0c8cc28f62e6195971363a Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 17 Aug 2021 16:24:50 +0800 Subject: [PATCH 2/5] add test --- .../test_fleet_sharding_meta_optimizer.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index d70a58c7d8ab4..6f17cba1ba575 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -859,6 +859,100 @@ def test_hybrid_with_sharding_pp_amp_fp16allreduce_in_optimize(self): self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) + def test_hybrid_with_pp_dp_amp_fp16allreduce_remove_param_cast(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + "remove_param_cast": True, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random', + 'cast', 'fill_constant', 'cast', 'uniform_random', 'cast', + 'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 
'fill_constant', 'fill_constant', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul', + 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul', + 'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'elementwise_mul', 'fill_constant', 'scale', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad', + 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2', + 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', + 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum', + 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', + 'c_sync_comm_stream', 'check_finite_and_unscale', 'cast', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', + 'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast', + 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum', + 'cast' + ]) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + self.assertIn(self.dp_ring_id, created_ring_ids) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_3": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + if __name__ == "__main__": unittest.main() From 4173793cf470dc742a68fd328dd081fc3d5d3014 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 17 Aug 2021 16:36:05 +0800 Subject: [PATCH 3/5] rename remove_param_cast to optimize_cast --- paddle/fluid/framework/distributed_strategy.proto | 2 +- .../distributed/fleet/base/distributed_strategy.py | 2 +- .../fleet/meta_optimizers/sharding/offload_helper.py | 2 +- .../fleet/meta_optimizers/sharding_optimizer.py | 12 ++++++------ .../unittests/test_fleet_sharding_meta_optimizer.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index b4a3caff1675b..546b9d2601df5 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -42,7 +42,7 @@ message ShardingConfig { optional bool optimize_offload = 9 [ default = false ]; optional bool pp_allreduce_in_optimize = 10 [ default = false ]; optional int32 pp_degree = 11 [ default = 1 ]; - optional bool remove_param_cast = 12 [ default = false ]; + optional bool optimize_cast = 12 [ default = false ]; } message HybridConfig { diff --git 
a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index f72a9df1653a2..849b7fb1da14f 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -888,7 +888,7 @@ def sharding_configs(self): pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on. This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False. - remove_param_cast(bool, optional): [Hybrid parallelism ONLY] Remove the cast OP from the fp32 param to the fp16 param. Remove fp32 param cast will persist fp16 param, it + optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimize. optimize_cast will persist fp16 param, it will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large. diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py index 46b6766ab1ce1..a96705b09e835 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -222,7 +222,7 @@ def remove_param(input_name): block._sync_with_cpp() startup_block._sync_with_cpp() - def remove_fp32param_cast(self, block, startup_block): + def cast_fp32param_in_optimize(self, block, startup_block): """ (p_fp16) = cast(p) (p_fp16_recompute) = cast(p) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 34a50c73d5f41..5c2f24054f835 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -400,14 +400,14 @@ def _apply_optimize_offload_pass(self): logger.info("Sharding with optimize offload !") offload_helper = OffloadHelper() offload_helper.offload(main_block, startup_block) - # The remove_fp32param_cast is already included in offload_fp32param + # The optimize_cast is already included in offload_fp32param offload_helper.offload_fp32param(main_block, startup_block) - elif sharding_configs['remove_param_cast']: - logger.info("Sharding with remove fp32param cast !") - # NOTE(wangxi): Remove fp32 param cast will persist fp16 param, it + elif sharding_configs['optimize_cast']: + logger.info("Sharding with optimize cast !") + # NOTE(wangxi): optimize_cast will persist fp16 param, it # will take more memory, but will be faster. Trade space for time. 
offload_helper = OffloadHelper() - offload_helper.remove_fp32param_cast(main_block, startup_block) + offload_helper.cast_fp32param_in_optimize(main_block, startup_block) def _dump_program_for_debug(self): main_block = self._main_program.global_block() @@ -451,7 +451,7 @@ def minimize_impl(self, # loss div dp_degree self._insert_loss_grad_scale_op() - # apply optimize offload or remove fp32param cast + # apply optimize offload or optimize cast self._apply_optimize_offload_pass() # step6: (optional) sharding gradient merge diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 6f17cba1ba575..e5cfc78fb4380 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -859,7 +859,7 @@ def test_hybrid_with_sharding_pp_amp_fp16allreduce_in_optimize(self): self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) - def test_hybrid_with_pp_dp_amp_fp16allreduce_remove_param_cast(self): + def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self): train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( ) avg_cost, strategy = self.pp_net(train_prog, startup_prog) @@ -871,7 +871,7 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_remove_param_cast(self): "mp_degree": 1, "pp_degree": 2, "dp_degree": 2, - "remove_param_cast": True, + "optimize_cast": True, } strategy.pipeline = True strategy.pipeline_configs = { From 29459547f3e33537bbe556df922a81e6483dc7f5 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 17 Aug 2021 16:39:09 +0800 Subject: [PATCH 4/5] fix --- python/paddle/distributed/fleet/base/distributed_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 849b7fb1da14f..d43292ddbd32e 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -888,7 +888,7 @@ def sharding_configs(self): pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on. This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False. - optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimize. optimize_cast will persist fp16 param, it + optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimizer. optimize_cast will persist fp16 param, it will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large. 
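For reference, a minimal sketch of how the optimize_cast flag introduced by this series would be switched on from user code, mirroring the configuration in the new unit test. The network definition producing avg_cost, the role maker, and the fleetrun launch environment are assumed rather than shown.

import paddle
import paddle.distributed.fleet as fleet

# Sketch under the assumptions above: a static-graph network that produces
# avg_cost is defined elsewhere and the script is started with fleetrun.
paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True  # optimize_cast acts on the fp32 -> fp16 casts inserted by AMP
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "dp_degree": 2,
    # move the param cast into the optimizer stage; fp16 params become
    # persistable, trading memory for fewer casts per step
    "optimize_cast": True,
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}

optimizer = paddle.fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# optimizer.minimize(avg_cost)  # avg_cost is the assumed loss variable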
From 8ca34b177435a44ad930b5f37347767ce6b8cd7a Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 17 Aug 2021 19:16:56 +0800 Subject: [PATCH 5/5] fix coverage --- .../test_fleet_sharding_meta_optimizer.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index e5cfc78fb4380..5a981a470cb4e 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -953,6 +953,103 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self): self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + "optimize_offload": True, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast', + 'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant', + 'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy', + 'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast', + 'memcpy', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul', + 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul', + 'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'elementwise_mul', 'fill_constant', 'scale', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad', + 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2', + 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', + 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum', + 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', + 'c_sync_comm_stream', 'check_finite_and_unscale', 'cast', + 
'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy', + 'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast', + 'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy', + 'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast', + 'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum', + 'memcpy', 'momentum', 'cast', 'memcpy' + ]) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + self.assertIn(self.dp_ring_id, created_ring_ids) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_3": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + if __name__ == "__main__": unittest.main()
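For completeness, the same kind of hedged sketch for the optimize_offload path exercised by the last test. As the branch in _apply_optimize_offload_pass shows, optimize_offload already includes the cast move performed by optimize_cast, so only one of the two flags needs to be set; the network definition and launch environment are again assumed.

import paddle
import paddle.distributed.fleet as fleet

# Sketch under the assumptions above. With optimize_offload the fp32 master
# params and optimizer states are kept off the GPU and copied back on demand
# around the update ops (the memcpy ops asserted in the test), trading step
# time for GPU memory.
paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 1,
    "mp_degree": 1,
    "pp_degree": 2,
    "dp_degree": 2,
    "optimize_offload": True,  # subsumes optimize_cast, per _apply_optimize_offload_pass
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}

optimizer = paddle.fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# optimizer.minimize(avg_cost)  # avg_cost is the assumed loss variable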