Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto Parallel] Move reduce to opt stage #62157

Merged
merged 13 commits into from
Mar 5, 2024
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
set_field_default_config(
GRADIENT_MERGE, "dp_gradient_sync_after_accumulate", False
)

#########################################
# pipeline configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,12 @@ def _apply_post_optimization(
)
dp_pass.apply([main_program], [startup_program], self._pass_context)

dp_gradient_sync_after_accumulate = (
self._strategy.gradient_merge.dp_gradient_sync_after_accumulate
)
if dp_gradient_sync_after_accumulate:
global_params_grads = params_grads

if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
Expand Down Expand Up @@ -485,7 +491,10 @@ def _apply_post_optimization(
if self.is_train and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
if dp_gradient_sync_after_accumulate:
config["params_grads"] = global_params_grads
else:
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", config
)
Expand Down
68 changes: 65 additions & 3 deletions python/paddle/distributed/passes/auto_parallel_gradient_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

import paddle
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
Expand Down Expand Up @@ -260,6 +264,48 @@ def _append_gradient_merge_backward_op(
return new_params_grads, grad_to_gradient_merge


def _move_reduce_to_optimizer_ops_block(
main_program, optimize_ops_block, params_grads
):
main_block = main_program.global_block()
removed_op_idx = []
params_grads_name = [grad.name for _, grad in params_grads]

for idx, op in list(enumerate(main_block.ops)):
if is_data_parallel_reduce_op(op):
op_input_names = op.desc.input_arg_names()
if "@RENAME" in op_input_names[0]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是针对一些特殊case的处理,加个NOTE说明下吧

continue

reduce_op_desc = optimize_ops_block.desc._insert_op(
len(removed_op_idx)
)
reduce_op_desc.copy_from(op.desc)
reduce_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
removed_op_idx.append(idx)

if op.type in ["c_allreduce_sum", "c_reduce_sum"]:
scale_index = idx + 1
while scale_index < len(main_block.ops):
if is_data_parallel_scale_op(main_block.ops[scale_index]):
scale_op_desc = optimize_ops_block.desc._insert_op(
len(removed_op_idx)
)
scale_op_desc.copy_from(
main_block.ops[scale_index].desc
)
scale_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
removed_op_idx.append(scale_index)
break
scale_index += 1

for idx in removed_op_idx[::-1]:
main_block._remove_op(idx, sync=False)

main_block._sync_with_cpp()
return optimize_ops_block


def _create_cond_block_and_update_optimizer(
main_program,
cond_var,
Expand Down Expand Up @@ -390,7 +436,13 @@ def true_apply_gradient():


def parse_program(
main_program, startup_program, params_grads, k_steps, avg, dist_context
main_program,
startup_program,
params_grads,
k_steps,
avg,
dist_context,
dp_gradient_sync_after_accumulate,
):
# 1 remove optimizer_op from main_program
optimize_ops_block = _remove_and_get_optimizer_op(
Expand All @@ -405,10 +457,16 @@ def parse_program(
main_program, startup_program, params_grads, dist_context
)

# 3 create gradient_merge_cond
if dp_gradient_sync_after_accumulate:
# 3 move reduce op to optimizer_ops_block
optimize_ops_block = _move_reduce_to_optimizer_ops_block(
main_program, optimize_ops_block, params_grads
)

# 4 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)

# 4 create ConditionalBlock and append gradient merge optimizer ops
# 5 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(
main_program,
cond_var,
Expand Down Expand Up @@ -444,6 +502,9 @@ def _apply_single_impl(self, main_program, startup_program, context):
avg = self.get_attr("avg", False)
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
dp_gradient_sync_after_accumulate = self.get_attr(
"dp_gradient_sync_after_accumulate", False
)
with paddle.static.program_guard(main_program, startup_program):
parse_program(
main_program,
Expand All @@ -452,6 +513,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
k_steps,
avg,
dist_context,
dp_gradient_sync_after_accumulate,
)

main_program._sync_with_cpp()