[Semi Auto] Refactor Completion Mechanism (Part2) #57751

Merged
34 commits merged on Oct 10, 2023

Commits (34)
4f82e54
first commit
JZ-LIANG Sep 15, 2023
d5fbf20
framework
JZ-LIANG Sep 15, 2023
8c3fbeb
matmul done
JZ-LIANG Sep 18, 2023
ec99d92
elementwise done
JZ-LIANG Sep 18, 2023
149a12f
adapt done
JZ-LIANG Sep 20, 2023
dae7e89
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 20, 2023
6da5eb1
polish code
JZ-LIANG Sep 20, 2023
47d8194
revise logging
JZ-LIANG Sep 21, 2023
202c677
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 21, 2023
4647fa2
revise log
JZ-LIANG Sep 21, 2023
9c9c68a
update doc
JZ-LIANG Sep 21, 2023
36a9ec4
enable LN unitest
JZ-LIANG Sep 22, 2023
aa339a5
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 22, 2023
ab5d11a
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 22, 2023
37b372c
precommit
JZ-LIANG Sep 22, 2023
375d57c
embedding & reshape
JZ-LIANG Sep 22, 2023
7708672
transpose
JZ-LIANG Sep 22, 2023
a40496e
split
JZ-LIANG Sep 22, 2023
a0bc730
default
JZ-LIANG Sep 22, 2023
054554b
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 26, 2023
154b027
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 27, 2023
050a02d
Merge remote-tracking branch 'upstream/develop' into semi-auto/comple…
JZ-LIANG Sep 27, 2023
1c806c1
bugfix in split
JZ-LIANG Sep 27, 2023
22ed519
bugfix
JZ-LIANG Sep 27, 2023
8abaa8b
support partial
JZ-LIANG Sep 27, 2023
f89e841
enable partial on static mode (for embed only)
JZ-LIANG Sep 27, 2023
9ffe43f
program print partial
JZ-LIANG Sep 27, 2023
9fe7d02
remove t
JZ-LIANG Sep 27, 2023
a6006bf
remove t
JZ-LIANG Sep 27, 2023
4d6e55d
remove t
JZ-LIANG Sep 27, 2023
7b247f6
bugfix: matmul grad allow empty grad when stop gradient
JZ-LIANG Oct 9, 2023
342a040
bugfix reshape unitest
JZ-LIANG Oct 9, 2023
dd351b3
bugfix for dist op cost test
JZ-LIANG Oct 9, 2023
c44566d
revise dist tensor partial print
JZ-LIANG Oct 9, 2023
18 changes: 18 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -51,6 +51,9 @@ namespace distributed {
PD_REGISTER_SPMD_RULE(matmul,
PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));
PD_REGISTER_SPMD_RULE(matmul_v2, // static mode
PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));

PD_REGISTER_SPMD_RULE(
elementwise_unary,
@@ -67,6 +70,10 @@ PD_REGISTER_SPMD_RULE(
unsqueeze,
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
default_,
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));

// replicated rule /* for unittest */
PD_REGISTER_SPMD_RULE(
@@ -465,6 +472,10 @@ PD_REGISTER_SPMD_RULE(
sum,
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
reduce_sum, // static
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));

// layer_norm
PD_REGISTER_SPMD_RULE(
@@ -476,6 +487,9 @@ PD_REGISTER_SPMD_RULE(
PD_REGISTER_SPMD_RULE(reshape,
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));
PD_REGISTER_SPMD_RULE(reshape2,
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));

// embedding rule
PD_REGISTER_SPMD_RULE(
@@ -501,6 +515,10 @@ PD_REGISTER_SPMD_RULE(
transpose,
PD_INFER_SPMD(phi::distributed::TransposeInferSpmd),
PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
transpose2,
PD_INFER_SPMD(phi::distributed::TransposeInferSpmd),
PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse));

} // namespace distributed
} // namespace phi
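
The static-mode aliases registered above (matmul_v2, reduce_sum, reshape2, transpose2, default_) reuse the existing PHI InferSpmd functions under the legacy static-graph op names, so the Python completion pass can look a rule up by op type. Below is a minimal sketch of that lookup path, mirroring the call pattern used by the operator files changed later in this PR; the import locations and the helper name are assumptions, not verbatim from the change.

# Minimal sketch (import paths assumed): fetch the rule registered for the
# static-mode "default_" alias and run forward/backward SPMD inference.
from paddle.distributed.auto_parallel.static.completion import get_phi_spmd_rule
from paddle.distributed.auto_parallel.static.utils import get_dist_tensor_spec


def infer_default_spmd(dist_op):
    input_names = dist_op.serial_op.desc.input_arg_names()
    output_names = dist_op.serial_op.desc.output_arg_names()
    # wrap each tensor's dims_mapping and process mesh into a DistTensorSpec
    input_specs = [get_dist_tensor_spec(dist_op, name) for name in input_names]
    output_specs = [
        get_dist_tensor_spec(dist_op, name, False) for name in output_names
    ]

    # "default_" resolves to the DefaultDataParallelInferSpmd pair registered above
    rule = get_phi_spmd_rule("default_")
    fw_results = rule.infer_forward(input_specs, output_specs)
    bw_results = rule.infer_backward(input_specs, output_specs)
    return fw_results, bw_results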
7 changes: 6 additions & 1 deletion python/paddle/distributed/auto_parallel/static/completion.py
@@ -130,7 +130,7 @@ def _can_apply_infer_spmd_rule(dist_op):
enable = True if enable == 'true' else False
enable = bool(enable)

# TODO remove me. ops to be adapted: lookup_table_v2, reshape2, split, transpose2,
# TODO remove me. ops to be adapted: squeeze2
__adapted_ops__ = [
"matmul_v2",
"elementwise_div",
@@ -143,6 +143,11 @@ def _can_apply_infer_spmd_rule(dist_op):
"dropout",
"reduce_sum",
"layer_norm",
"lookup_table_v2",
"reshape2",
"transpose2",
"split",
"unsqueeze2",
]
op_type = dist_op.serial_op.type
return enable and contains_spmd_rule(op_type) and op_type in __adapted_ops__
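
The allow-list above, together with the flag parsed at the top of _can_apply_infer_spmd_rule, decides whether completion routes an op through the new rule-based path. A hypothetical sketch of that dispatch (the call site is outside this hunk; every name except _can_apply_infer_spmd_rule is illustrative):

# Hypothetical dispatch sketch, not verbatim from completion.py.
def update_dims_mapping_for_op(dist_op, dist_op_container):
    if _can_apply_infer_spmd_rule(dist_op):
        # new path: the container's rule-based update_dims_mapping added in this PR series
        return dist_op_container.update_dims_mapping(dist_op)
    # otherwise keep the pre-existing, per-operator completion behavior
    return False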
26 changes: 22 additions & 4 deletions python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -150,8 +150,17 @@ def __str__(self):
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping

# partial
input_dist_attr = self.dist_attr.get_input_dist_attr(arg_name)
partial_dims = sorted(input_dist_attr._partial_dims())

str += "; {}'s dims_mapping (input, {}, {}): {}, partial on dims: {}".format(
arg_name,
annotated_str,
is_parameter_str,
dims_mapping,
partial_dims,
)

for arg_name in self.serial_op.desc.output_arg_names():
@@ -174,8 +183,17 @@ def __str__(self):
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping

# partial
output_dist_attr = self.dist_attr.get_output_dist_attr(arg_name)
partial_dims = sorted(output_dist_attr._partial_dims())

str += "; {}'s dims_mapping (output, {}, {}): {}, partial on dims: {}".format(
arg_name,
annotated_str,
is_parameter_str,
dims_mapping,
partial_dims,
)

str += ", dist_impl idx: {} , dist_impl type {} }}".format(
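
For reference, a minimal illustration of the new log fragment; only the format string comes from the change above, the argument values are made up:

# Illustration only: evaluate the new output-side format string with assumed values.
fragment = "; {}'s dims_mapping (output, {}, {}): {}, partial on dims: {}".format(
    "Out", "non-annotated", "non-parameter", [-1, -1], [0]
)
print(fragment)
# ; Out's dims_mapping (output, non-annotated, non-parameter): [-1, -1], partial on dims: [0]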
35 changes: 30 additions & 5 deletions python/paddle/distributed/auto_parallel/static/operators/common.py
@@ -622,12 +622,13 @@ def merge_forward_backward_dims_mapping(fw_results, bw_results):


def update_op_dims_mapping(
dist_op,
input_arg_names,
infered_input_dims_mappings,
output_arg_names,
infered_output_dims_mappings,
dist_op, input_arg_names, output_arg_names, fw_results, bw_results
):
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

op_dist_attr = dist_op.dist_attr
changed = False
assert len(input_arg_names) == len(
@@ -661,6 +662,7 @@ op_dist_attr.set_input_dims_mapping(
op_dist_attr.set_input_dims_mapping(
input_arg_names[i], infered_dims_mapping
)
# TODO support partial for inputs

for i in range(len(output_arg_names)):
original_dims_mapping = op_dist_attr.get_output_dims_mapping(
@@ -683,6 +685,29 @@ def update_op_dims_mapping(
output_arg_names[i], infered_dims_mapping
)

# NOTE in partial stage-I, we infer partial for output in infer_forward only
output_dist_attr = op_dist_attr.get_output_dist_attr(
output_arg_names[i]
)
output_idx = output_arg_names.index(output_arg_names[i])
if (
fw_results[1][output_idx]._partial_dims()
!= output_dist_attr._partial_dims()
):
_logger.info(
"Changed: Op [{}], tensor name [{}], Original partial on [{}], Infered partial on [{}]".format(
dist_op.serial_op.type,
output_arg_names[i],
output_dist_attr._partial_dims(),
fw_results[1][output_idx]._partial_dims(),
)
)
output_dist_attr._clean_partial_status()
output_dist_attr._set_partial_dims(
list(fw_results[1][0]._partial_dims())
)
changed = True

return changed


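
update_op_dims_mapping now receives the raw rule results and performs the forward/backward merge (plus, for outputs, the partial-status update) internally. A minimal sketch of the new calling convention as used by the operator files below; the wrapper function itself is illustrative:

# Sketch of the new calling convention: callers no longer invoke
# merge_forward_backward_dims_mapping themselves, they hand the rule results
# straight to update_op_dims_mapping.
from paddle.distributed.auto_parallel.static.operators.common import (
    update_op_dims_mapping,
)


def apply_spmd_rule(dist_op, rule, input_arg_names, output_arg_name, input_specs, output_spec):
    fw_results = rule.infer_forward(*input_specs)
    bw_results = rule.infer_backward(*input_specs, output_spec)
    return update_op_dims_mapping(
        dist_op, input_arg_names, [output_arg_name], fw_results, bw_results
    )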
python/paddle/distributed/auto_parallel/static/operators/dist_default.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License


from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

from ..completion import get_phi_spmd_rule
from ..cost import (
_g_op_cost_factory,
build_comp_costs_from_descs,
@@ -26,16 +28,19 @@
_get_comm_group,
_get_corresponding_rank,
compute_compatible_dim_mapping,
get_dist_tensor_spec,
is_prim_op,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
get_default_distributed_operator_impl,
gradient_synchronization,
is_parameter_related,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
)

__op_not_need_param_init__ = ["while", "cond"]
@@ -97,6 +102,61 @@ class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)

@staticmethod
def update_dims_mapping(dist_op):
# step1: prepare inputs needed for the rule (order args as PHI definition and filter out unnecessary args)

op_desc = dist_op.serial_op.desc
input_arg_names = op_desc.input_arg_names()
output_arg_names = op_desc.output_arg_names()

num_inputs = len(input_arg_names)
input_specs = []
for i in range(num_inputs):
assert not is_parameter_related(
input_arg_names[i]
), "input {} of op {} is parameter, op should not use default rule.".format(
input_arg_names[i], str(dist_op.serial_op)
)
input_specs.append(
get_dist_tensor_spec(dist_op, input_arg_names[i])
)
num_outputs = len(output_arg_names)
output_specs = []
for i in range(num_outputs):
assert not is_parameter_related(
output_arg_names[i]
), "output {} of op {} is parameter, op should not use default rule.".format(
output_arg_names[i], str(dist_op.serial_op)
)
output_specs.append(
get_dist_tensor_spec(dist_op, output_arg_names[i], False)
)

# step2: infer spmd
rule = get_phi_spmd_rule("default_")
# tensor order following order in PHI definition
fw_results = rule.infer_forward(input_specs, output_specs)
bw_results = rule.infer_backward(input_specs, output_specs)

# step3: update dist_attr
# tensor order following order in PHI definition
changed = update_op_dims_mapping(
dist_op, input_arg_names, output_arg_names, fw_results, bw_results
)

return changed

@staticmethod
def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr):
# all op use default dist operator impl.
op_dist_attr = dist_op.dist_attr
default_impl = get_default_distributed_operator_impl()
op_dist_attr.impl_type = default_impl.type
op_dist_attr.impl_idx = default_impl.idx

return False


register_distributed_operator_impl_container(DistributedDefault("default"))

python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py
@@ -62,24 +62,18 @@ def update_dims_mapping(dist_op):
fw_results = rule.infer_forward(x_spec)
bw_results = rule.infer_backward(x_spec, output_spec)

# step3: merge fw & bw results
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

# step4: update dist_attr
# step3: update dist_attr
# tensor order following order in PHI definition
changed = update_op_dims_mapping(
dist_op,
[x_name],
infered_input_dims_mappings,
[out_name],
infered_output_dims_mappings,
dist_op, [x_name], [out_name], fw_results, bw_results
)

# step5: update mask and seed dropout special
if changed:
(
_,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)
dist_op.dist_attr.set_output_dims_mapping(
mask_name, infered_output_dims_mappings[0]
)
python/paddle/distributed/auto_parallel/static/operators/dist_eltwise.py
Expand Up @@ -32,7 +32,6 @@
get_default_distributed_operator_impl,
is_elementwise_op,
is_parameter_related,
merge_forward_backward_dims_mapping,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
@@ -77,20 +76,10 @@ def update_dims_mapping(dist_op):
fw_results = rule.infer_forward(*input_specs)
bw_results = rule.infer_backward(*input_specs, output_spec)

# step3: merge fw & bw results
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

# step4: update dist_attr
# step3: update dist_attr
# tensor order following order in PHI definition
changed = update_op_dims_mapping(
dist_op,
input_arg_names,
infered_input_dims_mappings,
[output_arg_name],
infered_output_dims_mappings,
dist_op, input_arg_names, [output_arg_name], fw_results, bw_results
)

return changed