special deal with the inplace case in auto parallel pass. (#65060)
winter-wang authored Jun 16, 2024
1 parent 448c929 commit 02be1a8
Showing 9 changed files with 78 additions and 18 deletions.
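The core of the change is in `apply_partition_pass`: for operators whose result aliases an input (inplace ops such as `pd_op.assign_out_` or `pd_op.sgd_`), the pass now looks up the output-to-input alias map via `paddle.core.pir.get_op_inplace_info` and reshards the aliased operand and result so their dist attrs stay consistent, writing the result back into the original buffer with `paddle.assign` when needed. A minimal sketch of how that alias map can be inspected, assuming a distributed PIR program produced by the auto-parallel to-static flow (the helper name `dump_inplace_info` is illustrative, not part of this commit):

```python
import paddle


def dump_inplace_info(dist_program):
    """Print the {output_index: input_index} alias map of every inplace op."""
    for op in dist_program.global_block().ops:
        # get_op_inplace_info maps each aliased result index to the operand
        # index it writes into; the map is empty for non-inplace ops.
        inplace_map = paddle.core.pir.get_op_inplace_info(op)
        if inplace_map:
            print(op.name(), inplace_map)
```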
1 change: 0 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -158,7 +158,6 @@
     'c_allreduce_avg',
     'c_allreduce_max',
     'c_allreduce_min',
-    'c_allreduce_sum',
     'c_allreduce_prod',
     'c_embedding',
     'c_identity',
17 changes: 17 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -3866,6 +3866,7 @@ void AssignOut_Op::Build(pir::Builder &builder,
   std::vector<pir::Type> argument_outputs =
       AssignOut_Op::InferMeta(argument_inputs, &argument_attributes);
   argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
+  argument.AddAttributes(argument_attributes);
   constexpr char kStopGradientAttrName[] = "stop_gradient";
   auto stop_gradient0 =
       argument.inputs[0].attribute<pir::BoolAttribute>(kStopGradientAttrName);
@@ -3970,6 +3971,22 @@ std::vector<pir::Type> AssignOut_Op::InferMeta(
       dense_out.layout(),
       dense_out.lod(),
       dense_out.offset());
+#ifdef PADDLE_WITH_DISTRIBUTE
+  // Auto Parallel condition
+  if (auto dist_type = input_values[1].type().dyn_cast<DistTypeInterface>()) {
+    ProcessMeshAttribute op_mesh = dist_type.process_mesh_attr();
+    auto ctx = pir::IrContext::Instance();
+    std::vector<pir::Attribute> dist_operand_attrs{
+        dist_type.tensor_dist_attr(),
+        dist_type.tensor_dist_attr(),
+    },
+        dist_result_attrs{dist_type.tensor_dist_attr()};
+    argument_outputs.push_back(dist_type);
+    (*p_attributes)[kAttrOpDistAttr] = OperationDistAttribute::get(
+        ctx, op_mesh, dist_operand_attrs, dist_result_attrs);
+    return argument_outputs;
+  }
+#endif
   argument_outputs.push_back(out_dense_tensor_type);
 
   return argument_outputs;
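The new `PADDLE_WITH_DISTRIBUTE` branch makes `AssignOut_Op::InferMeta` forward the distributed type of the destination (the second input) instead of always emitting a plain dense tensor type, and the added `AddAttributes` call in `Build` is what attaches the inferred `op_dist_attr` to the created op. On the Python side, this is what lets the partition pass below write a resharded value back into the original buffer. A minimal sketch of that call pattern, with the hypothetical helper name `write_back`, assuming `result` and `prev_var` are `pir.Value`s in a distributed program:

```python
import paddle


def write_back(result, prev_var):
    # Reshard the op's output to the dist attr of the original (inplace) buffer,
    # then copy it in with paddle.assign, which lowers to pd_op.assign_out_.
    # With the InferMeta change above, the assign_out_ result keeps prev_var's
    # distributed type rather than decaying to a dense tensor type.
    reshard_var = paddle._C_ops.reshard_v2(result, prev_var.dist_attr())
    return paddle.assign(reshard_var, prev_var)
```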
46 changes: 45 additions & 1 deletion python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -69,6 +69,50 @@ def apply_partition_pass(program):
         assert len(op.operands()) == len(
             op.dist_attr.operands()
         ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+        assert len(op.results()) == len(
+            op.dist_attr.results()
+        ), f"The number of results and the number of op_dist_attr's results are not equal in op: {op}"
+        # deal with inplace values
+        for out_idx, in_idx in paddle.core.pir.get_op_inplace_info(op).items():
+            operand = op.operand(in_idx)
+            operand_attr = op.dist_attr.operand(in_idx)
+            prev_var = operand.source()
+            if not prev_var.is_dist() or operand_attr == prev_var.dist_attr():
+                continue
+            assert (
+                not prev_var.is_combine()
+            ), f"The current partition pass does not support an inplace value of {op} that is a tensor list."
+            operand_attr = operand_attr.as_tensor_dist_attr()
+            # reshard input
+            paddle.pir.set_insertion_point(op)
+            reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
+            operand.set_source(reshard_var)
+
+            result = op.result(out_idx)
+            result_attr = op.dist_attr.result(out_idx).as_tensor_dist_attr()
+            assert (
+                operand_attr == result_attr
+            ), f"For an inplace value, the operand dist attr should be equal to the result dist attr; please check the infer_spmd func of {op}"
+
+            # reshard output
+            paddle.pir.set_insertion_point_after(op)
+            old_dist_attr = result.dist_attr()
+            result.update_dist_attr(result_attr)
+
+            # reshard the output back to the assign_out_ input
+            reshard_var_1 = paddle._C_ops.reshard_v2(
+                result, prev_var.dist_attr()
+            )
+            paddle.assign(reshard_var_1, prev_var)
+
+            if old_dist_attr == result.dist_attr():
+                continue
+            reshard_var_2 = reshard_var_1
+            if old_dist_attr != reshard_var_1.dist_attr():
+                reshard_var_2 = paddle._C_ops.reshard_v2(result, old_dist_attr)
+            result.replace_all_uses_with(reshard_var_1)
+            reshard_var_1.get_defining_op().operand(0).set_source(result)
+            reshard_var_2.get_defining_op().operand(0).set_source(result)
+
         for operand, attr in zip(op.operands(), op.dist_attr.operands()):
             prev_var = operand.source()
@@ -187,7 +231,7 @@ def remove_other_rank_op_pass(dist_program):
 
 
 # Note: this is the pass in the dense program
-comm_ops = ["pd_op.c_allreduce_sum_", "pd_op.c_allgather"]
+comm_ops = ["pd_op.c_allreduce_sum", "pd_op.c_allgather"]
 
 
 def remove_unuseful_comm_op_pass(program):
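The inplace branch added above also relies on an invariant worth stating on its own: an aliased operand and result describe the same buffer, so the dist attrs inferred for them by `infer_spmd` must agree. A standalone sketch of that check, assuming `op.dist_attr` is `None` for ops without distributed attributes (the helper name `check_inplace_dist_attrs` is illustrative):

```python
import paddle


def check_inplace_dist_attrs(dist_program):
    for op in dist_program.global_block().ops:
        if op.dist_attr is None:
            # assumed: ops without distributed attributes expose dist_attr as None
            continue
        for out_idx, in_idx in paddle.core.pir.get_op_inplace_info(op).items():
            in_attr = op.dist_attr.operand(in_idx).as_tensor_dist_attr()
            out_attr = op.dist_attr.result(out_idx).as_tensor_dist_attr()
            # the aliased operand and result share one buffer, so their
            # inferred dist attrs must match (see the assert in the pass above)
            assert in_attr == out_attr, (
                f"inplace operand/result dist attrs differ in op: {op}"
            )
```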
@@ -48,7 +48,7 @@ def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
             reduce_mean = True
 
         group = new_process_group(sorted(src_mesh.process_ids))
-        reduced_value = paddle._C_ops.c_allreduce_sum_(
+        reduced_value = paddle._C_ops.c_allreduce_sum(
             src_value, group.id, True, False
         )
 
12 changes: 6 additions & 6 deletions test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh.py
@@ -99,8 +99,8 @@ def run_pp_to_rr_case(self):
         new_ops_name = [op.name() for op in dist_program.global_block().ops]
 
         rank_id = dist.get_rank()
-        assert new_ops_name[-2] == "pd_op.c_allreduce_sum_"
-        assert new_ops_name[-1] == "pd_op.c_allreduce_sum_"
+        assert new_ops_name[-2] == "pd_op.c_allreduce_sum"
+        assert new_ops_name[-1] == "pd_op.c_allreduce_sum"
 
         # check the first allreduce_sum
         op = new_ops[-2]
@@ -151,11 +151,11 @@ def run_pr_to_rs_case(self):
         new_ops_name = [op.name() for op in dist_program.global_block().ops]
 
         rank_id = dist.get_rank()
-        assert "pd_op.c_allreduce_sum_" in new_ops_name
+        assert "pd_op.c_allreduce_sum" in new_ops_name
         assert new_ops_name[-1] == "pd_op.slice"
 
         # check the allreduce_sum
-        op = new_ops[new_ops_name.index("pd_op.c_allreduce_sum_")]
+        op = new_ops[new_ops_name.index("pd_op.c_allreduce_sum")]
         if rank_id == 0 or rank_id == 2:
             process_ids = [0, 2]
         elif rank_id == 1 or rank_id == 3:
@@ -278,11 +278,11 @@ def run_ps_to_ps_case(self):
 
         ops = dist_program.global_block().ops
         op_names = [op.name() for op in ops]
-        assert "pd_op.c_allreduce_sum_" in op_names
+        assert "pd_op.c_allreduce_sum" in op_names
         assert "pd_op.c_allgather" in op_names
         assert "pd_op.slice" in op_names
 
-        allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum_")]
+        allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum")]
         allgather_op = ops[op_names.index("pd_op.c_allgather")]
         slice_op = ops[op_names.index("pd_op.slice")]
 
@@ -105,8 +105,8 @@ def run_pp_to_rr_case(self):
             assert new_ops_name[2] == "pd_op.send_v2"
         else:
             assert new_ops_name[2] == "pd_op.recv_v2"
-        assert new_ops_name[-2] == "pd_op.c_allreduce_sum_"
-        assert new_ops_name[-1] == "pd_op.c_allreduce_sum_"
+        assert new_ops_name[-2] == "pd_op.c_allreduce_sum"
+        assert new_ops_name[-1] == "pd_op.c_allreduce_sum"
 
         # check the first allreduce_sum
         op = new_ops[-2]
2 changes: 1 addition & 1 deletion test/auto_parallel/pir/test_to_static_pir_program.py
@@ -134,7 +134,7 @@ def test_to_static_program(self):
             "pd_op.sgd_",
             "pd_op.sgd_",
             "pd_op.relu_grad",
-            "pd_op.c_allreduce_sum_",
+            "pd_op.c_allreduce_sum",
             "pd_op.matmul_grad",
             "pd_op.relu_grad",
             "pd_op.matmul_grad",
6 changes: 3 additions & 3 deletions test/auto_parallel/reshard_p_to_r.py
@@ -97,12 +97,12 @@ def run_pir_static_test_case(self):
                 'builtin.parameter',
                 'pd_op.data',
                 'dist_op.shard_tensor',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ],
         )
 
         for op in ops:
-            if op.name() == 'pd_op.c_allreduce_sum_':
+            if op.name() == 'pd_op.c_allreduce_sum':
                 # check op dist_attr
                 assert op.dist_attr.num_operands() == 1
                 assert op.dist_attr.num_results() == 1
@@ -167,7 +167,7 @@ def run_pir_to_static_test_case(self):
             "pd_op.sgd_",
             "pd_op.sgd_",
             "pd_op.relu_grad",
-            "pd_op.c_allreduce_sum_",
+            "pd_op.c_allreduce_sum",
             "pd_op.matmul_grad",
             "pd_op.relu_grad",
             "pd_op.matmul_grad",
6 changes: 3 additions & 3 deletions test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -97,7 +97,7 @@ def run_pir_static_test_case(self):
                 'dist_op.shard_tensor',
                 'pd_op.send_v2',
                 'dist_op.reshard',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ]
         else:
             np.testing.assert_equal(main_program.num_ops(), 5)
@@ -106,7 +106,7 @@ def run_pir_static_test_case(self):
                 'pd_op.data',
                 'dist_op.shard_tensor',
                 'pd_op.recv_v2',
-                'pd_op.c_allreduce_sum_',
+                'pd_op.c_allreduce_sum',
             ]
         np.testing.assert_equal(
             ops,
@@ -141,7 +141,7 @@ def run_pir_static_test_case(self):
                 assert op_result_dist_attr.partial_status == {
                     0: paddle.distributed.ReduceType.kRedSum
                 }
-            elif op.name() == 'pd_op.c_allreduce_sum_':
+            elif op.name() == 'pd_op.c_allreduce_sum':
                 continue
             # check op dist_attr
             assert op.dist_attr.num_operands() == 1
