[Semi Auto] Refactor Completion Mechanism (Part2) (PaddlePaddle#57751)
* first commit

* framework

* matmul done

* elementwise done

* adapt done

* polish code

* revise logging

* revise log

* update doc

* enable LN unittest

* precommit

* embedding & reshape

* transpose

* split

* default

* bugfix in split

* bugfix

* support partial

* enable partial in static mode (for embedding only)

* program print partial

* remove t

* remove t

* remove t

* bugfix: matmul grad allows empty grad when stop_gradient is set

* bugfix reshape unittest

* bugfix for dist op cost test

* revise dist tensor partial print
JZ-LIANG authored and Frida-a committed Oct 14, 2023
1 parent 6e18b37 commit 9ed6279
Showing 19 changed files with 433 additions and 125 deletions.
18 changes: 18 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -52,6 +52,9 @@ namespace distributed {
PD_REGISTER_SPMD_RULE(matmul,
PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));
PD_REGISTER_SPMD_RULE(matmul_v2, // static mode
PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));

PD_REGISTER_SPMD_RULE(
elementwise_unary,
@@ -68,6 +71,10 @@ PD_REGISTER_SPMD_RULE(
unsqueeze,
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
default_,
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));

// replicated rule /* for unittest */
PD_REGISTER_SPMD_RULE(
@@ -466,6 +473,10 @@ PD_REGISTER_SPMD_RULE(
sum,
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
reduce_sum, // static
PD_INFER_SPMD(phi::distributed::ReductionInferSpmd),
PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse));

// layer_norm
PD_REGISTER_SPMD_RULE(
@@ -477,6 +488,9 @@ PD_REGISTER_SPMD_RULE(
PD_REGISTER_SPMD_RULE(reshape,
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));
PD_REGISTER_SPMD_RULE(reshape2,
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd),
PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse));

// embedding rule
PD_REGISTER_SPMD_RULE(
@@ -502,6 +516,10 @@ PD_REGISTER_SPMD_RULE(
transpose,
PD_INFER_SPMD(phi::distributed::TransposeInferSpmd),
PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse));
PD_REGISTER_SPMD_RULE(
transpose2,
PD_INFER_SPMD(phi::distributed::TransposeInferSpmd),
PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse));

// softmax rule
PD_REGISTER_SPMD_RULE(softmax,
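
The registrations added above bind the static-graph op names (matmul_v2, reshape2, transpose2, reduce_sum, and the default_ fallback) to the same InferSpmd / InferSpmdReverse functions already used by their dygraph counterparts. A minimal Python sketch of how these names can then be resolved from the completion pass; the import path mirrors the relative import used in the operator files of this commit, and the snippet is illustrative rather than a verbatim API reference:

# Illustrative sketch: looking up the newly registered static-mode rule names.
# Assumes get_phi_spmd_rule resolves names registered via PD_REGISTER_SPMD_RULE.
from paddle.distributed.auto_parallel.static.completion import get_phi_spmd_rule

dygraph_rule = get_phi_spmd_rule("matmul")     # existing registration
static_rule = get_phi_spmd_rule("matmul_v2")   # new alias for static mode

# Both rule objects are expected to expose infer_forward / infer_backward,
# backed by MatmulInferSpmd / MatmulInferSpmdReverse (see the operator diffs below).
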
7 changes: 6 additions & 1 deletion python/paddle/distributed/auto_parallel/static/completion.py
@@ -130,7 +130,7 @@ def _can_apply_infer_spmd_rule(dist_op):
enable = True if enable == 'true' else False
enable = bool(enable)

# TODO remove me. ops to be adapted: lookup_table_v2, reshape2, split, transpose2,
# TODO remove me. ops to be adapted: squeeze2
__adapted_ops__ = [
"matmul_v2",
"elementwise_div",
@@ -143,6 +143,11 @@ def _can_apply_infer_spmd_rule(dist_op):
"dropout",
"reduce_sum",
"layer_norm",
"lookup_table_v2",
"reshape2",
"transpose2",
"split",
"unsqueeze2",
]
op_type = dist_op.serial_op.type
return enable and contains_spmd_rule(op_type) and op_type in __adapted_ops__
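
This gate is what routes an op through the PHI SPMD path during completion: the toggle must be on, a rule must be registered for the op type, and the type must appear in __adapted_ops__, which this commit extends with lookup_table_v2, reshape2, transpose2, split and unsqueeze2 (squeeze2 remains a TODO). A self-contained sketch of the predicate, with the framework helper passed in as a parameter:

# Standalone sketch of the gating logic in _can_apply_infer_spmd_rule.
# contains_spmd_rule is injected because the real helper lives in the framework;
# the adapted set below is a subset of __adapted_ops__ for brevity.
def can_apply_infer_spmd_rule(op_type, enable, adapted_ops, contains_spmd_rule):
    return enable and contains_spmd_rule(op_type) and op_type in adapted_ops

adapted_ops = {"matmul_v2", "reshape2", "transpose2", "split", "unsqueeze2"}
print(can_apply_infer_spmd_rule("transpose2", True, adapted_ops, lambda t: True))  # True
print(can_apply_infer_spmd_rule("squeeze2", True, adapted_ops, lambda t: True))    # False: not yet adapted
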
26 changes: 22 additions & 4 deletions python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -150,8 +150,17 @@ def __str__(self):
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping

# partial
input_dist_attr = self.dist_attr.get_input_dist_attr(arg_name)
partial_dims = sorted(input_dist_attr._partial_dims())

str += "; {}'s dims_mapping (input, {}, {}): {}, partial on dims: {}".format(
arg_name,
annotated_str,
is_parameter_str,
dims_mapping,
partial_dims,
)

for arg_name in self.serial_op.desc.output_arg_names():
@@ -174,8 +183,17 @@ def __str__(self):
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping

# partial
output_dist_attr = self.dist_attr.get_output_dist_attr(arg_name)
partial_dims = sorted(output_dist_attr._partial_dims())

str += "; {}'s dims_mapping (output, {}, {}): {}, partial on dims: {}".format(
arg_name,
annotated_str,
is_parameter_str,
dims_mapping,
partial_dims,
)

str += ", dist_impl idx: {} , dist_impl type {} }}".format(
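
With the two format-string changes above, the per-argument summary produced by the dist op's __str__ now reports partial dimensions next to the dims_mapping. A runnable reproduction of just the new message layout, using hypothetical values (the argument name, mapping, and partial set are made up for illustration):

# Reproduces only the message layout introduced above; all values are hypothetical.
arg_name = "X"
annotated_str = "annotated"
is_parameter_str = "non-parameter"
dims_mapping = [0, -1]
partial_dims = [1]

line = "; {}'s dims_mapping (input, {}, {}): {}, partial on dims: {}".format(
    arg_name, annotated_str, is_parameter_str, dims_mapping, partial_dims
)
print(line)
# ; X's dims_mapping (input, annotated, non-parameter): [0, -1], partial on dims: [1]
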
35 changes: 30 additions & 5 deletions python/paddle/distributed/auto_parallel/static/operators/common.py
@@ -622,12 +622,13 @@ def merge_forward_backward_dims_mapping(fw_results, bw_results):


def update_op_dims_mapping(
dist_op,
input_arg_names,
infered_input_dims_mappings,
output_arg_names,
infered_output_dims_mappings,
dist_op, input_arg_names, output_arg_names, fw_results, bw_results
):
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

op_dist_attr = dist_op.dist_attr
changed = False
assert len(input_arg_names) == len(
@@ -661,6 +662,7 @@ def update_op_dims_mapping(
op_dist_attr.set_input_dims_mapping(
input_arg_names[i], infered_dims_mapping
)
# TODO support partial for inputs

for i in range(len(output_arg_names)):
original_dims_mapping = op_dist_attr.get_output_dims_mapping(
@@ -683,6 +685,29 @@ def update_op_dims_mapping(
output_arg_names[i], infered_dims_mapping
)

# NOTE in partial stage-I, we infer partial for output in infer_forward only
output_dist_attr = op_dist_attr.get_output_dist_attr(
output_arg_names[i]
)
output_idx = output_arg_names.index(output_arg_names[i])
if (
fw_results[1][output_idx]._partial_dims()
!= output_dist_attr._partial_dims()
):
_logger.info(
"Changed: Op [{}], tensor name [{}], Original partial on [{}], Infered partial on [{}]".format(
dist_op.serial_op.type,
output_arg_names[i],
output_dist_attr._partial_dims(),
fw_results[1][output_idx]._partial_dims(),
)
)
output_dist_attr._clean_partial_status()
output_dist_attr._set_partial_dims(
list(fw_results[1][0]._partial_dims())
)
changed = True

return changed


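
The new signature moves the forward/backward merge inside update_op_dims_mapping: callers hand over the raw infer_forward / infer_backward results instead of pre-merged dims mappings, and the helper additionally propagates inferred partial dims onto output dist_attrs (partial for inputs is still a TODO, per the note above). A caller-side sketch of the new contract, matching the operator diffs further down; rule, dist_op and the spec/name variables are placeholders for a concrete operator:

# Caller-side sketch only; rule, dist_op, x_spec, output_spec, x_name and
# out_name stand in for a concrete operator's objects. The point is the new
# contract: pass fw_results / bw_results directly, with no manual merge step.
fw_results = rule.infer_forward(x_spec)
bw_results = rule.infer_backward(x_spec, output_spec)

changed = update_op_dims_mapping(
    dist_op,
    [x_name],     # input arg names, in PHI definition order
    [out_name],   # output arg names, in PHI definition order
    fw_results,
    bw_results,
)
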
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License


from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

from ..completion import get_phi_spmd_rule
from ..cost import (
_g_op_cost_factory,
build_comp_costs_from_descs,
@@ -26,16 +28,19 @@
_get_comm_group,
_get_corresponding_rank,
compute_compatible_dim_mapping,
get_dist_tensor_spec,
is_prim_op,
set_dist_op_desc_original_id,
)
from .common import (
DistributedOperatorImpl,
DistributedOperatorImplContainer,
get_default_distributed_operator_impl,
gradient_synchronization,
is_parameter_related,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
)

__op_not_need_param_init__ = ["while", "cond"]
@@ -97,6 +102,61 @@ class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)

@staticmethod
def update_dims_mapping(dist_op):
# step1: prepare inputs need for rule (order args as PHI definition and filter out unnecessary args)

op_desc = dist_op.serial_op.desc
input_arg_names = op_desc.input_arg_names()
output_arg_names = op_desc.output_arg_names()

num_inputs = len(input_arg_names)
input_specs = []
for i in range(num_inputs):
assert not is_parameter_related(
input_arg_names[i]
), "input {} of op {} is parameter, op should not use default rule.".format(
input_arg_names[i], str(dist_op.serial_op)
)
input_specs.append(
get_dist_tensor_spec(dist_op, input_arg_names[i])
)
num_outputs = len(output_arg_names)
output_specs = []
for i in range(num_outputs):
assert not is_parameter_related(
output_arg_names[i]
), "output {} of op {} is parameter, op should not use default rule.".format(
output_arg_names[i], str(dist_op.serial_op)
)
output_specs.append(
get_dist_tensor_spec(dist_op, output_arg_names[i], False)
)

# step2: infer spmd
rule = get_phi_spmd_rule("default_")
# tensor order following order in PHI defition
fw_results = rule.infer_forward(input_specs, output_specs)
bw_results = rule.infer_backward(input_specs, output_specs)

# step3: update dist_attr
# tensor order following order in PHI defition
changed = update_op_dims_mapping(
dist_op, input_arg_names, output_arg_names, fw_results, bw_results
)

return changed

@staticmethod
def mapping_to_dist_operator_impl(dist_op, original_op_dist_attr):
# all op use default dist operator impl.
op_dist_attr = dist_op.dist_attr
default_impl = get_default_distributed_operator_impl()
op_dist_attr.impl_type = default_impl.type
op_dist_attr.impl_idx = default_impl.idx

return False


register_distributed_operator_impl_container(DistributedDefault("default"))

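
One design point in the block above: the default container collects every non-parameter input and output into lists and passes them to the rule as whole lists (infer_forward(input_specs, output_specs)), whereas per-op containers unpack individual specs, as in the dropout and elementwise diffs that follow. A short sketch contrasting the two call shapes, with placeholder rule and spec variables:

# Placeholder rules/specs; only the call shapes differ.
# Default rule: lists of specs, covering all non-parameter inputs/outputs.
fw_results = default_rule.infer_forward(input_specs, output_specs)
bw_results = default_rule.infer_backward(input_specs, output_specs)

# Per-op rules (e.g. dropout, elementwise below): individual specs, unpacked.
fw_results = dropout_rule.infer_forward(x_spec)
fw_results = eltwise_rule.infer_forward(*input_specs)
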
@@ -62,24 +62,18 @@ def update_dims_mapping(dist_op):
fw_results = rule.infer_forward(x_spec)
bw_results = rule.infer_backward(x_spec, output_spec)

# step3: merge fw & bw results
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

# step4: update dist_attr
# step3: update dist_attr
# tensor order following order in PHI defition
changed = update_op_dims_mapping(
dist_op,
[x_name],
infered_input_dims_mappings,
[out_name],
infered_output_dims_mappings,
dist_op, [x_name], [out_name], fw_results, bw_results
)

# step5: update mask and seed dropout special
if changed:
(
_,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)
dist_op.dist_attr.set_output_dims_mapping(
mask_name, infered_output_dims_mappings[0]
)
@@ -32,7 +32,6 @@
get_default_distributed_operator_impl,
is_elementwise_op,
is_parameter_related,
merge_forward_backward_dims_mapping,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
update_op_dims_mapping,
@@ -77,20 +76,10 @@ def update_dims_mapping(dist_op):
fw_results = rule.infer_forward(*input_specs)
bw_results = rule.infer_backward(*input_specs, output_spec)

# step3: merge fw & bw results
(
infered_input_dims_mappings,
infered_output_dims_mappings,
) = merge_forward_backward_dims_mapping(fw_results, bw_results)

# step4: update dist_attr
# step3: update dist_attr
# tensor order following order in PHI defition
changed = update_op_dims_mapping(
dist_op,
input_arg_names,
infered_input_dims_mappings,
[output_arg_name],
infered_output_dims_mappings,
dist_op, input_arg_names, [output_arg_name], fw_results, bw_results
)

return changed