[Auto Parallel] Support Primitive operators with Data Parallel #42709

Merged · 9 commits · May 19, 2022
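This change adds data-parallel support for programs lowered to primitive operators: a complete_prim_annotation pass that annotates every tensor and op with the global process mesh, a distributed implementation for the reduce_p primitive that all-reduces its locally reduced output across ranks, and a functor that appends c_allreduce_sum for parameter gradients and c_broadcast for parameter initialization when prim mode is enabled.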
67 changes: 67 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -1250,3 +1250,70 @@ def complete_update_annotation(self, serial_main_program=None):
                        self._dist_context.set_op_dist_attr_for_program(
                            op, op_dist_attr)
                        continue

    def complete_prim_annotation(self, serial_main_program=None):
        """
        Fill the default data-parallel annotation for a program composed of
        primitive operators.

        Arguments:
            serial_main_program: a partially annotated serial_main_program.
        Returns:
            None; serial_main_program is annotated in place.
        """
        if serial_main_program is None:
            serial_main_program = self._dist_context.serial_main_program
        else:
            self._dist_context.serial_main_program = serial_main_program

        self._dist_context._is_initialized = True
        self._dist_context._init_dist_attr_for_program()
        self._init_global_mesh_for_program()

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()
        self._dist_context.validate_dist_attr_for_program()

    def _init_global_mesh_for_program(self):
        # Annotate every tensor and op in the program with the global
        # (world) process mesh
        from paddle.distributed.auto_parallel.process_group import get_world_process_group
        world_ranks = get_world_process_group().ranks

        for block in self._dist_context._serial_main_program.blocks:
            for tensor in block.vars.values():
                # Copy the distributed tensors in the default context
                dist_tensor = self._dist_context.get_dist_tensor_for_program(
                    tensor)
                assert dist_tensor is not None
                dist_tensor.dist_attr.process_mesh = world_ranks
            for op in block.ops:
                # Copy the distributed operators in the default context
                dist_op = self._dist_context.get_dist_op_for_program(op)
                assert dist_op is not None
                dist_op.dist_attr.process_mesh = world_ranks

                # Find the most compatible implementations for the distributed operator
                op_dist_impls = find_best_compatible_distributed_operator_impl(
                    dist_op, fwd=True)
                if op_dist_impls is not None:
                    backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
                    for op_dist_impl in op_dist_impls:
                        dim_changed = op_dist_impl.update_dims_mapping(dist_op)
                        if op_dist_impl.is_auto_compatible(dist_op):
                            if op_dist_impl.type == "elementwise":
                                dist_op.dist_attr.impl_type = "default"
                            else:
                                dist_op.dist_attr.impl_type = op_dist_impl.type
                            dist_op.dist_attr.impl_idx = op_dist_impl.idx
                            break
                        else:
                            dist_op.dist_attr = backup_op_dist_attr
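
A minimal driver sketch for the new pass (hypothetical usage: a Completer wrapping a DistributedContext follows the existing auto_parallel API, and serial_main_program is assumed to be an already-built serial Program):

# Hypothetical driver; only complete_prim_annotation comes from this PR.
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

dist_context = DistributedContext()
completer = Completer(dist_context)
# Annotates every tensor and op in place with the world mesh and selects
# a compatible distributed impl for each primitive operator.
completer.complete_prim_annotation(serial_main_program)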
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -32,3 +32,4 @@
from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_p
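
Note: the module-level register_distributed_operator_impl_container call in dist_reduce_p (below) runs on this import, which is how the reduce_p rule becomes visible to completion and partitioning.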
7 changes: 4 additions & 3 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -24,9 +24,10 @@


def is_elementwise_op(op_type):
-    for eltwise_op in _g_elementwise_ops:
-        if eltwise_op in op_type:
-            return True
+    if op_type in _g_elementwise_ops:
+        return True
+    if "elementwise" in op_type:
+        return True
    return False


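A quick sketch of the revised matching behavior (illustrative op names; the exact contents of _g_elementwise_ops are defined earlier in this module):

# Exact membership is checked first, then the "elementwise" substring,
# so grad variants such as elementwise_add_grad still qualify.
assert is_elementwise_op("elementwise_add")
assert is_elementwise_op("elementwise_add_grad")
assert not is_elementwise_op("matmul_v2")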
63 changes: 57 additions & 6 deletions python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -18,7 +18,7 @@
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
-from ..utils import is_valid_list_index
+from ..utils import is_valid_list_index, is_prim_op
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
@@ -35,6 +35,55 @@
__op_not_need_param_init__ = ["while", "cond"]


def prim_operator_data_parallel_functor(ctx, src_op):
    dist_op_context = ctx.dist_op_context
    main_block = dist_op_context.work_block
    startup_block = dist_op_context.startup_block

    var_name = src_op.output_arg_names[0]
    if var_name in ctx.grads_params:
        assert var_name not in ctx.synced_gradient, "in primitive mode, grad [{}] is already synced".format(
            var_name)
        ctx.synced_gradient.add(var_name)
        sync_group = new_process_group(ctx.data_parallel_group)

        # All-reduce the parameter gradient across the data-parallel group.
        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [var_name]},
            outputs={'Out': [var_name]},
            attrs={
                'ring_id': sync_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Backward
            })

        # Broadcast the corresponding parameter from rank 0 at startup so
        # all replicas begin with identical weights.
        param = ctx.grads_params[var_name]
        new_op = startup_block.append_op(
            type='c_broadcast',
            inputs={'X': [param]},
            outputs={'Out': [param]},
            attrs={
                'ring_id': sync_group.id,
                'root': 0,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Forward
            })

        grad_var = main_block.var(var_name)
        dims_mapping = ctx.get_tensor_dist_attr_for_program(
            grad_var).dims_mapping
        dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        process_mesh = dist_attr.process_mesh
        op_attr = OperatorDistributedAttribute()
        op_attr.process_mesh = process_mesh
        op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
        op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
        ctx.set_op_dist_attr_for_program(allreduce_op, op_attr)

    return


class DistributedDefault(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedDefault, self).__init__(op_type)
@@ -292,7 +341,6 @@ def update_dims_mapping(self, dist_op):

    @staticmethod
    def forward(ctx, *args, **kwargs):
-
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
@@ -315,15 +363,20 @@ def forward(ctx, *args, **kwargs):
                output_name)

        # replicate op in dist program
-        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(src_op.desc)
        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
        for input_name in src_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in src_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

-        main_block._sync_with_cpp()
+        # data parallel synchronization for primitive operators
+        from paddle.incubate.autograd import prim_enabled
+        if prim_enabled():
+            assert is_prim_op(src_op)
+            prim_operator_data_parallel_functor(ctx, src_op)
+            return

        # param initialization sync
        if src_op.type in __op_not_need_param_init__:
@@ -373,8 +426,6 @@ def forward(ctx, *args, **kwargs):
                op_attr.set_input_dims_mapping(param.name, dims_mapping)
                ctx.set_op_dist_attr_for_program(new_op, op_attr)

-        startup_block._sync_with_cpp()
-
    @staticmethod
    def backward(ctx, *args, **kwargs):

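Note: replicating the op via main_block.append_op(type='nop') and then copying the source desc into it keeps the Python block and its underlying C++ desc consistent, which is presumably what allows the explicit _sync_with_cpp() calls to be dropped here.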
151 changes: 151 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py
@@ -0,0 +1,151 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl, is_parameter_related
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank


class DistributedReducePrimtive(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedReducePrimtive, self).__init__(op_type)


register_distributed_operator_impl_container(
    DistributedReducePrimtive("reduce_p"))


# Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedReducePrimtiveImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr

        return len(op_desc.input_arg_names()) == 1

    def is_output_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        outputs = op_desc.output_arg_names()

        if len(outputs) != 1:
            return False

        output_name = outputs[0]
        output_var = dist_op.serial_op.block.var(output_name)
        if output_var.shape != (1, ):
            return False

        return True

    def is_auto_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr

        return self.is_input_compatible(dist_op) and self.is_output_compatible(
            dist_op)

    def update_dims_mapping(self, dist_op):
        changed = False

        return changed

    @staticmethod
    def forward(ctx, *args, **kwargs):
        dist_op_context = ctx.dist_op_context
        main_block = dist_op_context.work_block
        startup_block = dist_op_context.startup_block
        src_op = dist_op_context.cur_src_op
        rank_id = dist_op_context.rank_id

        # check validation of inputs / outputs
        for input_name in src_op.desc.input_names():
            assert input_name in kwargs, "input [{}] is not given".format(
                input_name)
            assert len(kwargs[input_name]) == len(
                src_op.desc.input(input_name)
            ), "number of tensors for input [{}] does not match".format(
                input_name)
        for output_name in src_op.desc.output_names():
            assert output_name in kwargs, "output [{}] is not given".format(
                output_name)
            assert len(kwargs[output_name]) == len(
                src_op.desc.output(output_name)
            ), "number of tensors for output [{}] does not match".format(
                output_name)

        # replicate op in dist program
        dist_op_desc = main_block.append_op(type='nop').desc
        dist_op_desc.copy_from(src_op.desc)
        set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
        for input_name in src_op.desc.input_names():
            dist_op_desc.set_input(input_name, kwargs[input_name])
        for output_name in src_op.desc.output_names():
            dist_op_desc.set_output(output_name, kwargs[output_name])

        # batch dimension synchronization
        var_name = src_op.output_arg_names[0]
        sync_group = new_process_group(ctx.data_parallel_group)
        allreduce_op = main_block.append_op(
            type='c_allreduce_sum',
            inputs={'X': [var_name]},
            outputs={'Out': [var_name]},
            attrs={
                'ring_id': sync_group.id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Forward
            })

        # dist attr
        var = main_block.var(var_name)
        tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        new_op_attr = OperatorDistributedAttribute()
        new_op_attr.process_mesh = op_dist_attr.process_mesh
        new_op_attr.set_output_dims_mapping(var.name,
                                            tensor_dist_attr.dims_mapping)
        new_op_attr.set_input_dims_mapping(var.name,
                                           tensor_dist_attr.dims_mapping)
        ctx.set_op_dist_attr_for_program(allreduce_op, new_op_attr)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise RuntimeError(
            "primitive operator does NOT have backward function, op type: {}".
            format(ctx.dist_op_context.cur_src_op.type))


register_distributed_operator_impl(
    "reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p"))
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/partitioner.py
@@ -263,6 +263,11 @@ def partition_block(self, ref_block, target_block):
                dist_op_backward_impl.backward(
                    self._dist_context, **kinputs, **koutputs,
                    **{"grad_var_to_var": grad_var_to_var})
            elif int(op.attr('op_role')) == 2:
                # op_role == 2 marks optimizer ops (OpRole.Optimize); under prim
                # lowering they are dispatched to the default implementation,
                # which replicates them on each rank.
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_impl = get_distributed_operator_impl_container(
                    "default").get_impl(0)
                dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
            else:
                raise NotImplementedError(
                    "partitioner only supports forward and backward ops, but got {}".
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1101,6 +1101,10 @@ def is_loss_op(op):
        int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))


def is_prim_op(op):
    # Lowered primitive operators all carry a "_p" suffix (e.g. reduce_p).
    return op.type.endswith("_p")


def get_loss_op(block):
    loss_ops = []
    for op in block.ops:
@@ -1118,6 +1122,9 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
    tensor_dist_attr.dims_mapping = dims_mapping
    # TODO get global mesh group
    tensor_dist_attr.process_mesh = process_mesh
    if "mark_annotated" in kwargs and kwargs["mark_annotated"]:
        tensor_dist_attr.mark_annotated("dims_mapping")
        tensor_dist_attr.mark_annotated("process_mesh")
    dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
    return tensor_dist_attr

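A hedged usage sketch of the new mark_annotated flag (the variable and mesh arguments are illustrative):

# Hypothetical call: annotate a tensor and mark both attributes as
# user-annotated so later completion passes will not overwrite them.
tensor_dist_attr = set_var_dist_attr(
    dist_context, loss_var, [-1], world_process_mesh, mark_annotated=True)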