From e52301b7d45688398ca05a5da1e7db94b3a03fe1 Mon Sep 17 00:00:00 2001 From: andsonder Date: Thu, 14 Mar 2024 15:14:56 +0000 Subject: [PATCH 1/6] extract split matmul_grad_op to pass_utils --- .../allreduce_matmul_grad_overlapping.py | 151 +++------------ .../paddle/distributed/passes/pass_utils.py | 178 ++++++++++++++++++ 2 files changed, 202 insertions(+), 127 deletions(-) diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index 89e6c20ad03c97..3123d09d121299 100644 --- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -20,7 +20,7 @@ naive_set_dist_op_attr_for_program_by_mesh_and_mapping, ) from .pass_base import PassBase, register_pass -from .pass_utils import AutoParallelStreamType +from .pass_utils import _split_matmul_grad_to_matmul, AutoParallelStreamType logger = get_logger(logging.INFO) @@ -140,13 +140,11 @@ def _split_matmul_grad_and_multi_streaming_allreduce( # uninitialized tensors. Therefore, we move the cast operation to the back of the # second matmul operation to avoid this problem. skip_overlapping = False - moved_ops_idx = [] moved_ops_output = [] matmul_grad_output = matmul_grad_op.output('Y@GRAD')[0] for idx in range(matmul_grad_id + 1, allreduce_id): if matmul_grad_output in ops[idx].desc.input_arg_names(): - moved_ops_idx.append(idx) moved_ops_output.extend(ops[idx].desc.output_arg_names()) else: for input_name in ops[idx].desc.input_arg_names(): @@ -155,138 +153,37 @@ def _split_matmul_grad_and_multi_streaming_allreduce( if skip_overlapping: continue - - for i, idx in enumerate(moved_ops_idx): - op = ops[idx] - dist_attr = self.dist_context.get_op_dist_attr_for_program(op) - - op_inputs = op.desc.input_names() - op_outputs = op.desc.output_names() - - op_inputs = {name: op.input(name) for name in op_inputs} - op_outputs = {name: op.output(name) for name in op_outputs} - - op = block._insert_op_without_sync( - index=allreduce_id + 1 + i, - type=op.type, - inputs=op_inputs, - outputs=op_outputs, - attrs=op.all_attrs(), - ) - - self.dist_context.set_op_dist_attr_for_program(op, dist_attr) - - for i, idx in enumerate(moved_ops_idx): - block._remove_op(idx - i, sync=False) - allreduce_id -= 1 - - tran_x = matmul_grad_op.attr("trans_x") - assert ( - not tran_x - ), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for column parallel linear backward overlapping" - tran_y = matmul_grad_op.attr("trans_y") - assert ( - not tran_y - ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for column parallel linear backward overlapping" - - allreduce_op.dist_attr.execution_stream = ( - AutoParallelStreamType.MP_STREAM.value + + # matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape + _split_matmul_grad_to_matmul( + block, matmul_grad_id, self.op_namescope, self.dist_context ) - x = matmul_grad_op.input("X") - y = matmul_grad_op.input("Y") - out_grad = matmul_grad_op.input("Out@GRAD") - x_grad = matmul_grad_op.output("X@GRAD") - y_grad = matmul_grad_op.output("Y@GRAD") - op_role = matmul_grad_op.attr("op_role") - # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. 
Otherwise, after buffer_shared_inplace_pass # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then # the matmul op before it cannot be overlapped. - var_x = block.var(x[0]) - var_out_grad = block.var(out_grad[0]) - var_y_grad = block.var(y_grad[0]) - - x_dims = var_x.shape - out_grad_dims = var_out_grad.shape - y_grad_dims = var_y_grad.shape - - assert len(x_dims) == len( - out_grad_dims - ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}." - if len(x_dims) > 2: - assert ( - x_dims[0:2] == out_grad_dims[0:2] - ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}." - new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:]) - new_out_grad_dims = [ - out_grad_dims[0] * out_grad_dims[1] - ] + list(out_grad_dims[2:]) - - # NOTE(Ruibiao): Why insert reshape op here? - # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad. - # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding. - # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul. - new_x = self._insert_reshape_op( - block, allreduce_id + 1, x, new_x_dims, op_role + allreduce_op_dist_attr = self.dist_context.get_op_dist_attr_for_program( + allreduce_op ) - new_out_grad = self._insert_reshape_op( - block, allreduce_id + 2, out_grad, new_out_grad_dims, op_role - ) - new_y_grad = block.create_var( - name=f"{y_grad[0]}@reshape.out", - dtype=var_y_grad.dtype, - persistable=False, - ) - self.dist_context.set_tensor_dist_attr_for_program( - new_y_grad, - self.dist_context.get_tensor_dist_attr_for_program(var_y_grad), - ) - - matmul_grad_dist_attr = ( - self.dist_context.get_op_dist_attr_for_program(matmul_grad_op) - ) - matmul_op = block._insert_op_without_sync( - index=allreduce_id + 3, - type="matmul_v2", - inputs={"X": new_x, "Y": new_out_grad}, - outputs={"Out": new_y_grad}, - attrs={ - "trans_x": True, - "trans_y": False, - "op_role": op_role, - 'op_namescope': self.op_namescope, - }, - ) - self.dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr - ) - - self._insert_reshape_op( - block, - allreduce_id + 4, - [new_y_grad.name], - y_grad_dims, - op_role, - y_grad, - ) - - matmul_op = block._insert_op_without_sync( - index=matmul_grad_id + 1, - type="matmul_v2", - inputs={"X": out_grad, "Y": y}, - outputs={"Out": x_grad}, - attrs={ - "trans_x": False, - "trans_y": True, - "op_role": op_role, - 'op_namescope': self.op_namescope, - }, + + allreduce_op_inputs = allreduce_op.desc.input_names() + allreduce_op_outputs = allreduce_op.desc.output_names() + + allreduce_op_inputs = {name: allreduce_op.input(name) for name in allreduce_op_inputs} + allreduce_op_outputs = {name: allreduce_op.output(name) for name in allreduce_op_outputs} + + allreduce_op = block._insert_op_without_sync( + index=allreduce_id + 1, + type=allreduce_op.type, + inputs=allreduce_op_inputs, + outputs=allreduce_op_outputs, + attrs=allreduce_op.all_attrs(), ) self.dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr + allreduce_op, allreduce_op_dist_attr ) - - block._remove_op(matmul_grad_id, sync=False) 
+ # Remove the original allreduce op + block._remove_op(allreduce_id + 5, sync=False) + block._sync_with_cpp() diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index f1dcc8a7ffd797..ff213349de86f8 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -29,6 +29,9 @@ use_new_executor, ) from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.auto_parallel.static.utils import ( + naive_set_dist_op_attr_for_program_by_mesh_and_mapping +) __not_shape_var_type__ = [ core.VarDesc.VarType.READER, @@ -785,3 +788,178 @@ def _add_event_dependency(recorder_op, waiter_op): if recorder_op.dist_attr.event_to_record not in waiter_wait_list: waiter_wait_list.append(recorder_op.dist_attr.event_to_record) waiter_op.dist_attr.events_to_wait = waiter_wait_list + + +def _insert_reshape_op(block, index, x, shape, op_role, op_namescope="/", out=None, dist_context=None): + var_x = block.var(x[0]) + if dist_context: + x_dist_attr = dist_context.get_tensor_dist_attr_for_program(var_x) + + if out is None: + out = block.create_var( + name=f"{x[0]}@reshape.out", + dtype=var_x.dtype, + persistable=False, + ) + if dist_context: + dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr) + + x_shape = block.create_var( + name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype + ) + if dist_context: + dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr) + + reshape_op = block._insert_op_without_sync( + index=index, + type="reshape2", + inputs={"X": x}, + outputs={"Out": out, "XShape": x_shape}, + attrs={ + "shape": shape, + "op_role": op_role, + 'op_namescope': op_namescope, + }, + ) + if dist_context: + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + reshape_op, + process_mesh=x_dist_attr.process_mesh, + ref_mapping=x_dist_attr.dims_mapping, + ctx=dist_context, + chunk_id=x_dist_attr.chunk_id, + ) + + return out + + +def _split_matmul_grad_to_matmul( + block, matmul_grad_id, op_namescope="/", dist_context=None +): + ops = block.ops + matmul_grad_op = ops[matmul_grad_id] + + tran_x = matmul_grad_op.attr("trans_x") + assert ( + not tran_x + ), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for spliting matmul_grad to matmul" + tran_y = matmul_grad_op.attr("trans_y") + assert ( + not tran_y + ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for spliting matmul_grad to matmul" + + x = matmul_grad_op.input("X") + y = matmul_grad_op.input("Y") + out_grad = matmul_grad_op.input("Out@GRAD") + x_grad = matmul_grad_op.output("X@GRAD") + y_grad = matmul_grad_op.output("Y@GRAD") + op_role = matmul_grad_op.attr("op_role") + + # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). + # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass + # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then + # the matmul op before it cannot be overlapped. + var_x = block.var(x[0]) + var_out_grad = block.var(out_grad[0]) + var_y_grad = block.var(y_grad[0]) + + x_dims = var_x.shape + out_grad_dims = var_out_grad.shape + y_grad_dims = var_y_grad.shape + + assert len(x_dims) == len( + out_grad_dims + ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}." 
+ if len(x_dims) > 2: + assert ( + x_dims[0:2] == out_grad_dims[0:2] + ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}." + new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:]) + new_out_grad_dims = [ + out_grad_dims[0] * out_grad_dims[1] + ] + list(out_grad_dims[2:]) + + # NOTE(Ruibiao): Why insert reshape op here? + # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad. + # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding. + # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul. + new_x = _insert_reshape_op( + block, + matmul_grad_id + 1, + x, + new_x_dims, + op_role, + op_namescope=op_namescope, + dist_context=dist_context + ) + new_out_grad = _insert_reshape_op( + block, + matmul_grad_id + 2, + out_grad, + new_out_grad_dims, + op_role, + op_namescope=op_namescope, + dist_context=dist_context + ) + new_y_grad = block.create_var( + name=f"{y_grad[0]}@reshape.out", + dtype=var_y_grad.dtype, + persistable=False, + ) + + if dist_context: + dist_context.set_tensor_dist_attr_for_program( + new_y_grad, + dist_context.get_tensor_dist_attr_for_program(var_y_grad), + ) + + if dist_context: + matmul_grad_dist_attr = ( + dist_context.get_op_dist_attr_for_program(matmul_grad_op) + ) + + matmul_op = block._insert_op_without_sync( + index=matmul_grad_id + 3, + type="matmul_v2", + inputs={"X": new_x, "Y": new_out_grad}, + outputs={"Out": new_y_grad}, + attrs={ + "trans_x": True, + "trans_y": False, + "op_role": op_role, + 'op_namescope': op_namescope, + }, + ) + if dist_context: + dist_context.set_op_dist_attr_for_program( + matmul_op, matmul_grad_dist_attr + ) + _insert_reshape_op( + block, + matmul_grad_id + 4, + [new_y_grad.name], + y_grad_dims, + op_role, + y_grad, + op_namescope=op_namescope, + dist_context=dist_context + ) + + matmul_op = block._insert_op_without_sync( + index=matmul_grad_id + 1, + type="matmul_v2", + inputs={"X": out_grad, "Y": y}, + outputs={"Out": x_grad}, + attrs={ + "trans_x": False, + "trans_y": True, + "op_role": op_role, + 'op_namescope': op_namescope, + }, + ) + if dist_context: + dist_context.set_op_dist_attr_for_program( + matmul_op, matmul_grad_dist_attr + ) + + block._remove_op(matmul_grad_id, sync=False) From b4ee57d964275ae51b27ed118553d6f45eeed043 Mon Sep 17 00:00:00 2001 From: andsonder Date: Fri, 15 Mar 2024 02:28:02 +0000 Subject: [PATCH 2/6] fix --- .../allreduce_matmul_grad_overlapping.py | 63 ++++------------- .../paddle/distributed/passes/pass_utils.py | 69 ++++++++++--------- 2 files changed, 51 insertions(+), 81 deletions(-) diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index 3123d09d121299..cad7b6264eb648 100644 --- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -17,10 +17,9 @@ from ..auto_parallel.static.utils import ( get_logger, - naive_set_dist_op_attr_for_program_by_mesh_and_mapping, ) from .pass_base import PassBase, register_pass -from .pass_utils import _split_matmul_grad_to_matmul, AutoParallelStreamType +from .pass_utils 
import _split_matmul_grad_to_matmul logger = get_logger(logging.INFO) @@ -84,44 +83,6 @@ def _get_all_matmul_grad_and_allreduce_pairs(self, block): matmul_grad_id_to_allreduce_id[i] = j return matmul_grad_id_to_allreduce_id - def _insert_reshape_op(self, block, index, x, shape, op_role, out=None): - var_x = block.var(x[0]) - x_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(var_x) - - if out is None: - out = block.create_var( - name=f"{x[0]}@reshape.out", - dtype=var_x.dtype, - persistable=False, - ) - self.dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr) - - x_shape = block.create_var( - name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype - ) - self.dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr) - - reshape_op = block._insert_op_without_sync( - index=index, - type="reshape2", - inputs={"X": x}, - outputs={"Out": out, "XShape": x_shape}, - attrs={ - "shape": shape, - "op_role": op_role, - 'op_namescope': self.op_namescope, - }, - ) - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - reshape_op, - process_mesh=x_dist_attr.process_mesh, - ref_mapping=x_dist_attr.dims_mapping, - ctx=self.dist_context, - chunk_id=x_dist_attr.chunk_id, - ) - - return out - def _split_matmul_grad_and_multi_streaming_allreduce( self, block, matmul_grad_id_to_allreduce_id ): @@ -153,7 +114,7 @@ def _split_matmul_grad_and_multi_streaming_allreduce( if skip_overlapping: continue - + # matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape _split_matmul_grad_to_matmul( block, matmul_grad_id, self.op_namescope, self.dist_context @@ -163,16 +124,20 @@ def _split_matmul_grad_and_multi_streaming_allreduce( # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then # the matmul op before it cannot be overlapped. 
- allreduce_op_dist_attr = self.dist_context.get_op_dist_attr_for_program( - allreduce_op + allreduce_op_dist_attr = ( + self.dist_context.get_op_dist_attr_for_program(allreduce_op) ) - + allreduce_op_inputs = allreduce_op.desc.input_names() allreduce_op_outputs = allreduce_op.desc.output_names() - - allreduce_op_inputs = {name: allreduce_op.input(name) for name in allreduce_op_inputs} - allreduce_op_outputs = {name: allreduce_op.output(name) for name in allreduce_op_outputs} - + + allreduce_op_inputs = { + name: allreduce_op.input(name) for name in allreduce_op_inputs + } + allreduce_op_outputs = { + name: allreduce_op.output(name) for name in allreduce_op_outputs + } + allreduce_op = block._insert_op_without_sync( index=allreduce_id + 1, type=allreduce_op.type, @@ -185,5 +150,5 @@ def _split_matmul_grad_and_multi_streaming_allreduce( ) # Remove the original allreduce op block._remove_op(allreduce_id + 5, sync=False) - + block._sync_with_cpp() diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index ff213349de86f8..939ad804f72fba 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -26,12 +26,10 @@ is_backward_op, is_forward_op, is_optimize_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, use_new_executor, ) from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.distributed.auto_parallel.static.utils import ( - naive_set_dist_op_attr_for_program_by_mesh_and_mapping -) __not_shape_var_type__ = [ core.VarDesc.VarType.READER, @@ -790,7 +788,16 @@ def _add_event_dependency(recorder_op, waiter_op): waiter_op.dist_attr.events_to_wait = waiter_wait_list -def _insert_reshape_op(block, index, x, shape, op_role, op_namescope="/", out=None, dist_context=None): +def _insert_reshape_op( + block, + index, + x, + shape, + op_role, + out=None, + op_namescope="/", + dist_context=None, +): var_x = block.var(x[0]) if dist_context: x_dist_attr = dist_context.get_tensor_dist_attr_for_program(var_x) @@ -804,9 +811,7 @@ def _insert_reshape_op(block, index, x, shape, op_role, op_namescope="/", out=No if dist_context: dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr) - x_shape = block.create_var( - name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype - ) + x_shape = block.create_var(name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype) if dist_context: dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr) @@ -838,7 +843,7 @@ def _split_matmul_grad_to_matmul( ): ops = block.ops matmul_grad_op = ops[matmul_grad_id] - + tran_x = matmul_grad_op.attr("trans_x") assert ( not tran_x @@ -847,14 +852,14 @@ def _split_matmul_grad_to_matmul( assert ( not tran_y ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for spliting matmul_grad to matmul" - + x = matmul_grad_op.input("X") y = matmul_grad_op.input("Y") out_grad = matmul_grad_op.input("Out@GRAD") x_grad = matmul_grad_op.output("X@GRAD") y_grad = matmul_grad_op.output("Y@GRAD") op_role = matmul_grad_op.attr("op_role") - + # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. 
Otherwise, after buffer_shared_inplace_pass # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then @@ -862,11 +867,11 @@ def _split_matmul_grad_to_matmul( var_x = block.var(x[0]) var_out_grad = block.var(out_grad[0]) var_y_grad = block.var(y_grad[0]) - + x_dims = var_x.shape out_grad_dims = var_out_grad.shape y_grad_dims = var_y_grad.shape - + assert len(x_dims) == len( out_grad_dims ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}." @@ -875,47 +880,47 @@ def _split_matmul_grad_to_matmul( x_dims[0:2] == out_grad_dims[0:2] ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}." new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:]) - new_out_grad_dims = [ - out_grad_dims[0] * out_grad_dims[1] - ] + list(out_grad_dims[2:]) - + new_out_grad_dims = [out_grad_dims[0] * out_grad_dims[1]] + list( + out_grad_dims[2:] + ) + # NOTE(Ruibiao): Why insert reshape op here? # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad. # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding. # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul. new_x = _insert_reshape_op( - block, - matmul_grad_id + 1, - x, - new_x_dims, + block, + matmul_grad_id + 1, + x, + new_x_dims, op_role, op_namescope=op_namescope, - dist_context=dist_context + dist_context=dist_context, ) new_out_grad = _insert_reshape_op( - block, - matmul_grad_id + 2, - out_grad, - new_out_grad_dims, + block, + matmul_grad_id + 2, + out_grad, + new_out_grad_dims, op_role, op_namescope=op_namescope, - dist_context=dist_context + dist_context=dist_context, ) new_y_grad = block.create_var( name=f"{y_grad[0]}@reshape.out", dtype=var_y_grad.dtype, persistable=False, ) - + if dist_context: dist_context.set_tensor_dist_attr_for_program( new_y_grad, dist_context.get_tensor_dist_attr_for_program(var_y_grad), ) - + if dist_context: - matmul_grad_dist_attr = ( - dist_context.get_op_dist_attr_for_program(matmul_grad_op) + matmul_grad_dist_attr = dist_context.get_op_dist_attr_for_program( + matmul_grad_op ) matmul_op = block._insert_op_without_sync( @@ -942,9 +947,9 @@ def _split_matmul_grad_to_matmul( op_role, y_grad, op_namescope=op_namescope, - dist_context=dist_context + dist_context=dist_context, ) - + matmul_op = block._insert_op_without_sync( index=matmul_grad_id + 1, type="matmul_v2", From 617d24857975e886e83d974d7337fc768b9724ab Mon Sep 17 00:00:00 2001 From: andsonder Date: Fri, 15 Mar 2024 16:27:27 +0000 Subject: [PATCH 3/6] apply suggestions from code review --- .../allreduce_matmul_grad_overlapping.py | 11 ++- .../paddle/distributed/passes/pass_utils.py | 71 ++++++++----------- 2 files changed, 35 insertions(+), 47 deletions(-) diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index cad7b6264eb648..1f78d04eb1a3ac 100644 --- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -94,12 +94,9 @@ def 
_split_matmul_grad_and_multi_streaming_allreduce( matmul_grad_op = ops[matmul_grad_id] allreduce_op = ops[allreduce_id] - # NOTE(Sonder): Why move those operations to the back of matmul_v2? - # When using amp_master_grad, the cast operation is inserted after matmul_grad. - # However, when employing allreduce_matmul_grad_overlapping, the matmul_grad is - # split into two matmul operations. In this case, some operations would access - # uninitialized tensors. Therefore, we move the cast operation to the back of the - # second matmul operation to avoid this problem. + # NOTE(Sonder): When there are ops between matmul_grad and allreduce, we should check whether the + # these ops rely on the output of the intermediate ops. If so, we should not split the matmul_grad. + # Otherwise, the output of the intermediate ops will get wrong results. skip_overlapping = False moved_ops_output = [] matmul_grad_output = matmul_grad_op.output('Y@GRAD')[0] @@ -117,7 +114,7 @@ def _split_matmul_grad_and_multi_streaming_allreduce( # matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape _split_matmul_grad_to_matmul( - block, matmul_grad_id, self.op_namescope, self.dist_context + block, matmul_grad_id, self.dist_context, self.op_namescope ) # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index 939ad804f72fba..84643abea86a1b 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -794,13 +794,12 @@ def _insert_reshape_op( x, shape, op_role, + dist_context, out=None, op_namescope="/", - dist_context=None, ): var_x = block.var(x[0]) - if dist_context: - x_dist_attr = dist_context.get_tensor_dist_attr_for_program(var_x) + x_dist_attr = dist_context.get_tensor_dist_attr_for_program(var_x) if out is None: out = block.create_var( @@ -808,12 +807,10 @@ def _insert_reshape_op( dtype=var_x.dtype, persistable=False, ) - if dist_context: - dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr) + dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr) x_shape = block.create_var(name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype) - if dist_context: - dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr) + dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr) reshape_op = block._insert_op_without_sync( index=index, @@ -826,20 +823,20 @@ def _insert_reshape_op( 'op_namescope': op_namescope, }, ) - if dist_context: - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - reshape_op, - process_mesh=x_dist_attr.process_mesh, - ref_mapping=x_dist_attr.dims_mapping, - ctx=dist_context, - chunk_id=x_dist_attr.chunk_id, - ) + + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + reshape_op, + process_mesh=x_dist_attr.process_mesh, + ref_mapping=x_dist_attr.dims_mapping, + ctx=dist_context, + chunk_id=x_dist_attr.chunk_id, + ) return out def _split_matmul_grad_to_matmul( - block, matmul_grad_id, op_namescope="/", dist_context=None + block, matmul_grad_id, dist_context, op_namescope="/" ): ops = block.ops matmul_grad_op = ops[matmul_grad_id] @@ -860,10 +857,6 @@ def _split_matmul_grad_to_matmul( y_grad = matmul_grad_op.output("Y@GRAD") op_role = matmul_grad_op.attr("op_role") - # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). 
- # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass - # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then - # the matmul op before it cannot be overlapped. var_x = block.var(x[0]) var_out_grad = block.var(out_grad[0]) var_y_grad = block.var(y_grad[0]) @@ -894,8 +887,8 @@ def _split_matmul_grad_to_matmul( x, new_x_dims, op_role, - op_namescope=op_namescope, dist_context=dist_context, + op_namescope=op_namescope, ) new_out_grad = _insert_reshape_op( block, @@ -903,8 +896,8 @@ def _split_matmul_grad_to_matmul( out_grad, new_out_grad_dims, op_role, - op_namescope=op_namescope, dist_context=dist_context, + op_namescope=op_namescope, ) new_y_grad = block.create_var( name=f"{y_grad[0]}@reshape.out", @@ -912,16 +905,14 @@ def _split_matmul_grad_to_matmul( persistable=False, ) - if dist_context: - dist_context.set_tensor_dist_attr_for_program( - new_y_grad, - dist_context.get_tensor_dist_attr_for_program(var_y_grad), - ) + dist_context.set_tensor_dist_attr_for_program( + new_y_grad, + dist_context.get_tensor_dist_attr_for_program(var_y_grad), + ) - if dist_context: - matmul_grad_dist_attr = dist_context.get_op_dist_attr_for_program( - matmul_grad_op - ) + matmul_grad_dist_attr = dist_context.get_op_dist_attr_for_program( + matmul_grad_op + ) matmul_op = block._insert_op_without_sync( index=matmul_grad_id + 3, @@ -935,10 +926,10 @@ def _split_matmul_grad_to_matmul( 'op_namescope': op_namescope, }, ) - if dist_context: - dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr - ) + + dist_context.set_op_dist_attr_for_program( + matmul_op, matmul_grad_dist_attr + ) _insert_reshape_op( block, matmul_grad_id + 4, @@ -946,8 +937,8 @@ def _split_matmul_grad_to_matmul( y_grad_dims, op_role, y_grad, - op_namescope=op_namescope, dist_context=dist_context, + op_namescope=op_namescope, ) matmul_op = block._insert_op_without_sync( @@ -962,9 +953,9 @@ def _split_matmul_grad_to_matmul( 'op_namescope': op_namescope, }, ) - if dist_context: - dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr - ) + + dist_context.set_op_dist_attr_for_program( + matmul_op, matmul_grad_dist_attr + ) block._remove_op(matmul_grad_id, sync=False) From ea00f08e041247f62b8627db7895bfeec4805d66 Mon Sep 17 00:00:00 2001 From: andsonder Date: Fri, 15 Mar 2024 16:29:20 +0000 Subject: [PATCH 4/6] update --- .../passes/allreduce_matmul_grad_overlapping.py | 2 +- python/paddle/distributed/passes/pass_utils.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index 1f78d04eb1a3ac..de8efb1bf7b4a0 100644 --- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -94,7 +94,7 @@ def _split_matmul_grad_and_multi_streaming_allreduce( matmul_grad_op = ops[matmul_grad_id] allreduce_op = ops[allreduce_id] - # NOTE(Sonder): When there are ops between matmul_grad and allreduce, we should check whether the + # NOTE(Sonder): When there are ops between matmul_grad and allreduce, we should check whether # these ops rely on the output of the intermediate ops. If so, we should not split the matmul_grad. # Otherwise, the output of the intermediate ops will get wrong results. 
skip_overlapping = False diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index 84643abea86a1b..fd167e4b7bdd8a 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -823,7 +823,7 @@ def _insert_reshape_op( 'op_namescope': op_namescope, }, ) - + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( reshape_op, process_mesh=x_dist_attr.process_mesh, @@ -927,9 +927,7 @@ def _split_matmul_grad_to_matmul( }, ) - dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr - ) + dist_context.set_op_dist_attr_for_program(matmul_op, matmul_grad_dist_attr) _insert_reshape_op( block, matmul_grad_id + 4, @@ -954,8 +952,6 @@ def _split_matmul_grad_to_matmul( }, ) - dist_context.set_op_dist_attr_for_program( - matmul_op, matmul_grad_dist_attr - ) + dist_context.set_op_dist_attr_for_program(matmul_op, matmul_grad_dist_attr) block._remove_op(matmul_grad_id, sync=False) From 37e8ca033907d068dee175e178c4247f9015eb7f Mon Sep 17 00:00:00 2001 From: andsonder Date: Sat, 16 Mar 2024 03:22:39 +0000 Subject: [PATCH 5/6] fix --- python/paddle/distributed/passes/pass_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index fd167e4b7bdd8a..d467d5eaa8b38f 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -934,8 +934,8 @@ def _split_matmul_grad_to_matmul( [new_y_grad.name], y_grad_dims, op_role, - y_grad, dist_context=dist_context, + out=y_grad, op_namescope=op_namescope, ) From 8f4867eea359c7fb30048d75ae6803ddf5ce6fc4 Mon Sep 17 00:00:00 2001 From: andsonder Date: Sat, 16 Mar 2024 04:35:37 +0000 Subject: [PATCH 6/6] change func name --- .../distributed/passes/allreduce_matmul_grad_overlapping.py | 4 ++-- python/paddle/distributed/passes/pass_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py index de8efb1bf7b4a0..e1e4514b60d24d 100644 --- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py +++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py @@ -19,7 +19,7 @@ get_logger, ) from .pass_base import PassBase, register_pass -from .pass_utils import _split_matmul_grad_to_matmul +from .pass_utils import split_matmul_grad_to_matmul logger = get_logger(logging.INFO) @@ -113,7 +113,7 @@ def _split_matmul_grad_and_multi_streaming_allreduce( continue # matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape - _split_matmul_grad_to_matmul( + split_matmul_grad_to_matmul( block, matmul_grad_id, self.dist_context, self.op_namescope ) diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index d467d5eaa8b38f..a8064e90535203 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -835,7 +835,7 @@ def _insert_reshape_op( return out -def _split_matmul_grad_to_matmul( +def split_matmul_grad_to_matmul( block, matmul_grad_id, dist_context, op_namescope="/" ): ops = block.ops
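
For reference, the refactor keeps the same decomposition the pass already performed: under the trans_x == trans_y == False condition that split_matmul_grad_to_matmul asserts, the fused matmul_grad is replaced by matmul(dOut, Y^T) for X@GRAD and matmul(X^T, dOut) for Y@GRAD, with reshape2 ops folding the leading two dimensions of rank-3 inputs so the second product runs as one GEMM instead of a batched GEMM. The arithmetic can be checked with a minimal NumPy sketch (not Paddle code; the function name and shapes below are illustrative only):

import numpy as np

def split_matmul_grad_reference(x, out_grad, y):
    # dX = matmul(dOut, Y^T): the first matmul_v2 the pass inserts, whose
    # result feeds c_allreduce_sum and can overlap with the second matmul.
    x_grad = np.matmul(out_grad, y.T)

    # Fold the leading two dims of X and dOut (FoldInitDims) so dY is
    # computed with a single GEMM; this mirrors the inserted reshape2 ops.
    new_x = x.reshape((-1,) + x.shape[2:]) if x.ndim > 2 else x
    new_out_grad = (
        out_grad.reshape((-1,) + out_grad.shape[2:])
        if out_grad.ndim > 2
        else out_grad
    )

    # dY = matmul(X^T, dOut) on the folded inputs; its shape already matches
    # Y, so this sketch needs no trailing reshape.
    y_grad = np.matmul(new_x.T, new_out_grad)
    return x_grad, y_grad

# Tiny shape check on a rank-3 input, the case the reshapes target.
x = np.random.rand(2, 3, 4)         # [batch, seq, k]
y = np.random.rand(4, 5)            # [k, n]
out_grad = np.random.rand(2, 3, 5)  # same leading dims as x
x_grad, y_grad = split_matmul_grad_reference(x, out_grad, y)
assert x_grad.shape == x.shape and y_grad.shape == y.shape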